1use crate::response::{
2 ChatCompletion, ChatCompletionStream, JSONChoiceStream, Message, ModelType, TextChoiceStream,
3};
4use anyhow::{anyhow, Ok, Result};
5use schemars::schema::SchemaObject;
6use serde::{de::DeserializeOwned, ser::SerializeStruct, Deserialize, Serialize, Serializer};
7
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10pub struct FrequencyPenalty(pub f32);
11
12impl FrequencyPenalty {
13 pub fn new(v: f32) -> Result<Self> {
23 if !(-2.0..=2.0).contains(&v) {
24 return Err(anyhow!(
25 "Frequency penalty value must be between -2 and 2.".to_string()
26 ));
27 }
28 Ok(FrequencyPenalty(v))
29 }
30}
31
32impl Default for FrequencyPenalty {
33 fn default() -> Self {
35 FrequencyPenalty(0.0)
36 }
37}
38
39#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
41pub struct PresencePenalty(pub f32);
42
43impl PresencePenalty {
44 pub fn new(v: f32) -> Result<Self> {
54 if !(-2.0..=2.0).contains(&v) {
55 return Err(anyhow!(
56 "Presence penalty value must be between -2 and 2.".to_string()
57 ));
58 }
59 Ok(PresencePenalty(v))
60 }
61}
62
63impl Default for PresencePenalty {
64 fn default() -> Self {
66 PresencePenalty(0.0)
67 }
68}
69
70#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
72pub enum ResponseType {
73 #[serde(rename = "json_object")]
74 Json,
75 #[serde(rename = "text")]
76 Text,
77}
78
79#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
81pub struct ResponseFormat {
82 #[serde(rename = "type")]
83 pub resp_type: ResponseType,
84}
85
86impl ResponseFormat {
87 pub fn new(rt: ResponseType) -> Self {
93 ResponseFormat { resp_type: rt }
94 }
95}
96
97#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
99pub struct MaxToken(pub u32);
100
101impl MaxToken {
102 pub fn new(v: u32) -> Result<Self> {
112 if !(1..=8192).contains(&v) {
113 return Err(anyhow!("Max token must be between 1 and 8192.".to_string()));
114 }
115 Ok(MaxToken(v))
116 }
117}
118
119impl Default for MaxToken {
120 fn default() -> Self {
122 MaxToken(4096)
123 }
124}
125
126#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
128pub enum Stop {
129 Single(String),
130 Multiple(Vec<String>),
131}
132
133#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
135pub struct StreamOptions {
136 pub include_usage: bool,
137}
138
139impl StreamOptions {
140 pub fn new(include_usage: bool) -> Self {
146 StreamOptions { include_usage }
147 }
148}
149
150#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
152pub struct Temperature(pub u32);
153
154impl Temperature {
155 pub fn new(v: u32) -> Result<Self> {
165 if v > 2 {
166 return Err(anyhow!("Temperature must be between 0 and 2.".to_string()));
167 }
168 Ok(Temperature(v))
169 }
170}
171
172impl Default for Temperature {
173 fn default() -> Self {
175 Temperature(1)
176 }
177}
178
179#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
181pub struct TopP(pub f32);
182
183impl TopP {
184 pub fn new(v: f32) -> Result<Self> {
194 if !(0.0..=1.0).contains(&v) {
195 return Err(anyhow!("TopP value must be between 0and 2.".to_string()));
196 }
197 Ok(TopP(v))
198 }
199}
200
201impl Default for TopP {
202 fn default() -> Self {
204 TopP(1.0)
205 }
206}
207
208#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
210pub enum ToolType {
211 #[serde(rename = "function")]
212 Function,
213}
214
215#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
217pub struct Function {
218 pub description: String,
219 pub name: String,
220 pub parameters: SchemaObject,
221}
222
223#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
225pub struct ToolObject {
226 #[serde(rename = "type")]
227 pub tool_type: ToolType,
228 pub function: Function,
229}
230
231#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
233pub enum ChatCompletionToolChoice {
234 #[serde(rename = "none")]
235 None,
236 #[serde(rename = "auto")]
237 Auto,
238 #[serde(rename = "required")]
239 Required,
240}
241
242#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
244pub struct FunctionChoice {
245 pub name: String,
246}
247
248#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
250pub struct ChatCompletionNamedToolChoice {
251 #[serde(rename = "type")]
252 pub tool_type: ToolType,
253 pub function: FunctionChoice,
254}
255
256#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
258pub enum ToolChoice {
259 ChatCompletion(ChatCompletionToolChoice),
260 ChatCompletionNamed(ChatCompletionNamedToolChoice),
261}
262
263#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
265pub struct TopLogprobs(pub u32);
266
267impl TopLogprobs {
268 pub fn new(v: u32) -> Result<Self> {
278 if v > 20 {
279 return Err(anyhow!(
280 "Top log probs must be between 0 and 20.".to_string()
281 ));
282 }
283 Ok(TopLogprobs(v))
284 }
285}
286
287impl Default for TopLogprobs {
288 fn default() -> Self {
290 TopLogprobs(0)
291 }
292}
293
294#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
296#[serde(tag = "role")]
297pub enum MessageRequest {
298 #[serde(rename = "system")]
299 System(SystemMessageRequest),
300 #[serde(rename = "user")]
301 User(UserMessageRequest),
302 #[serde(rename = "assistant")]
303 Assistant(AssistantMessageRequest),
304 #[serde(rename = "tool")]
305 Tool(ToolMessageRequest),
306}
307
308impl MessageRequest {
309 pub fn from_message(resp_message: &Message) -> Result<Self> {
319 match resp_message.role.as_str() {
320 "system" => Ok(MessageRequest::System(SystemMessageRequest {
321 content: resp_message.content.clone(),
322 name: None,
323 })),
324 "user" => Ok(MessageRequest::User(UserMessageRequest {
325 content: resp_message.content.clone(),
326 name: None,
327 })),
328 "assistant" => {
329 let request = match resp_message.reasoning_content.clone() {
330 Some(reasoning_content) => {
331 AssistantMessageRequest::new(resp_message.content.as_str())
332 .set_reasoning_content(reasoning_content.as_str())
333 }
334 None => AssistantMessageRequest::new(resp_message.content.as_str()),
335 };
336 Ok(MessageRequest::Assistant(request))
337 }
338 "tool" => Ok(MessageRequest::Tool(ToolMessageRequest {
339 content: resp_message.content.clone(),
340 tool_call_id: "".to_string(), })),
342 _ => Err(anyhow!("Invalid message role.".to_string())),
343 }
344 }
345
346 pub fn get_content(&self) -> String {
347 match self {
348 MessageRequest::System(req) => req.content.clone(),
349 MessageRequest::User(req) => req.content.clone(),
350 MessageRequest::Assistant(req) => req.content.clone(),
351 MessageRequest::Tool(req) => req.content.clone(),
352 }
353 }
354}
355
356#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
358pub struct SystemMessageRequest {
359 pub content: String,
360 pub name: Option<String>,
361}
362
363impl SystemMessageRequest {
364 pub fn new(msg: &str) -> Self {
370 SystemMessageRequest {
371 content: msg.to_string(),
372 name: None,
373 }
374 }
375
376 pub fn new_with_name(name: &str, msg: &str) -> Self {
383 SystemMessageRequest {
384 content: msg.to_string(),
385 name: Some(name.to_string()),
386 }
387 }
388}
389
390#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
392pub struct UserMessageRequest {
393 pub content: String,
394 pub name: Option<String>,
395}
396
397impl UserMessageRequest {
398 pub fn new(msg: &str) -> Self {
404 UserMessageRequest {
405 content: msg.to_string(),
406 name: None,
407 }
408 }
409
410 pub fn new_with_name(name: &str, msg: &str) -> Self {
417 UserMessageRequest {
418 content: msg.to_string(),
419 name: Some(name.to_string()),
420 }
421 }
422}
423
424#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
426pub struct AssistantMessageRequest {
427 pub content: String,
428 pub name: Option<String>,
429 pub prefix: bool,
430 pub reasoning_content: Option<String>,
431}
432
433impl AssistantMessageRequest {
434 pub fn new(msg: &str) -> Self {
440 AssistantMessageRequest {
441 content: msg.to_string(),
442 name: None,
443 prefix: false,
444 reasoning_content: None,
445 }
446 }
447
448 pub fn new_with_name(name: &str, msg: &str) -> Self {
455 AssistantMessageRequest {
456 content: msg.to_string(),
457 name: Some(name.to_string()),
458 prefix: false,
459 reasoning_content: None,
460 }
461 }
462
463 pub fn set_reasoning_content(mut self, content: &str) -> Self {
473 self.prefix = true;
474 self.reasoning_content = Some(content.to_string());
475 self
476 }
477
478 pub fn set_prefix(mut self, content: &str) -> Self {
479 self.prefix = true;
480 self.content = content.to_string();
481 self
482 }
483}
484
485#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
487pub struct ToolMessageRequest {
488 pub content: String,
489 pub tool_call_id: String,
490}
491
492impl ToolMessageRequest {
493 pub fn new(msg: &str, tool_call_id: &str) -> Self {
500 ToolMessageRequest {
501 content: msg.to_string(),
502 tool_call_id: tool_call_id.to_string(),
503 }
504 }
505}
506
507pub trait RequestBuilder {
508 type Request: Serialize;
509 type Response: DeserializeOwned;
510 type Item: DeserializeOwned + Send + 'static;
511
512 fn is_beta(&self) -> bool;
513 fn is_stream(&self) -> bool;
514 fn build(self) -> Self::Request;
515}
516
517#[derive(Debug, Default, Clone, PartialEq, Deserialize)]
519pub struct CompletionsRequest {
520 pub messages: Vec<MessageRequest>,
521 pub model: ModelType,
522 pub prompt: String,
523 #[serde(skip_serializing_if = "Option::is_none")]
524 pub max_tokens: Option<MaxToken>,
525 #[serde(skip_serializing_if = "Option::is_none")]
526 pub response_format: Option<ResponseFormat>,
527 #[serde(skip_serializing_if = "Option::is_none")]
528 pub stop: Option<Stop>,
529 pub stream: bool,
530 #[serde(skip_serializing_if = "Option::is_none")]
531 pub stream_options: Option<StreamOptions>,
532 #[serde(skip_serializing_if = "Option::is_none")]
533 pub tools: Option<Vec<ToolObject>>,
534 #[serde(skip_serializing_if = "Option::is_none")]
535 pub tool_choice: Option<ToolChoice>,
536
537 #[serde(skip_serializing_if = "Option::is_none")]
539 pub temperature: Option<Temperature>,
540 #[serde(skip_serializing_if = "Option::is_none")]
541 pub top_p: Option<TopP>,
542 #[serde(skip_serializing_if = "Option::is_none")]
543 pub presence_penalty: Option<PresencePenalty>,
544 #[serde(skip_serializing_if = "Option::is_none")]
545 pub frequency_penalty: Option<FrequencyPenalty>,
546 #[serde(skip_serializing_if = "Option::is_none")]
547 pub logprobs: Option<bool>,
548 #[serde(skip_serializing_if = "Option::is_none")]
549 pub top_logprobs: Option<TopLogprobs>,
550}
551
552impl Serialize for CompletionsRequest {
553 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
554 where
555 S: Serializer,
556 {
557 let mut state = serializer.serialize_struct("CompletionsRequest", 12)?;
558
559 state.serialize_field("messages", &self.messages)?;
560 state.serialize_field("model", &self.model)?;
561 state.serialize_field("max_tokens", &self.max_tokens)?;
562 state.serialize_field("response_format", &self.response_format)?;
563 state.serialize_field("stop", &self.stop)?;
564 state.serialize_field("stream", &self.stream)?;
565 state.serialize_field("stream_options", &self.stream_options)?;
566 state.serialize_field("tools", &self.tools)?;
567 state.serialize_field("tool_choice", &self.tool_choice)?;
568 state.serialize_field("prompt", &self.prompt)?;
569
570 if self.model != ModelType::DeepSeekReasoner {
572 state.serialize_field("temperature", &self.temperature)?;
573 state.serialize_field("top_p", &self.top_p)?;
574 state.serialize_field("presence_penalty", &self.presence_penalty)?;
575 state.serialize_field("frequency_penalty", &self.frequency_penalty)?;
576 state.serialize_field("logprobs", &self.logprobs)?;
577 state.serialize_field("top_logprobs", &self.top_logprobs)?;
578 }
579
580 state.end()
581 }
582}
583
584#[derive(Debug, Default)]
585pub struct CompletionsRequestBuilder {
586 beta: bool,
587 messages: Vec<MessageRequest>,
588 model: ModelType,
589
590 stream: bool,
591 stream_options: Option<StreamOptions>,
592
593 max_tokens: Option<MaxToken>,
594 response_format: Option<ResponseFormat>,
595 stop: Option<Stop>,
596 tools: Option<Vec<ToolObject>>,
597 tool_choice: Option<ToolChoice>,
598 prompt: String,
599 temperature: Option<Temperature>,
600 top_p: Option<TopP>,
601 presence_penalty: Option<PresencePenalty>,
602 frequency_penalty: Option<FrequencyPenalty>,
603 logprobs: Option<bool>,
604 top_logprobs: Option<TopLogprobs>,
605}
606
607impl CompletionsRequestBuilder {
608 pub fn new(messages: Vec<MessageRequest>) -> Self {
609 Self {
610 messages,
611 model: ModelType::DeepSeekChat,
612 prompt: String::new(),
613 ..Default::default()
614 }
615 }
616 pub fn use_model(mut self, model: ModelType) -> Self {
617 self.model = model;
618 self
619 }
620
621 pub fn append_fim_message(self, _prompt: &str, _suffix: &str) -> Self {
623 todo!("Not enough detail in document")
624 }
625
626 pub fn append_prefix_message(mut self, msg: &str) -> Self {
628 self.messages.push(MessageRequest::Assistant(
629 AssistantMessageRequest::new(msg).set_prefix(msg),
630 ));
631 self
632 }
633
634 pub fn append_user_message(mut self, msg: &str) -> Self {
635 self.messages
636 .push(MessageRequest::User(UserMessageRequest::new(msg)));
637 self
638 }
639
640 pub fn max_tokens(mut self, value: u32) -> Result<Self> {
641 self.max_tokens = Some(MaxToken::new(value)?);
642 Ok(self)
643 }
644
645 pub fn use_beta(mut self, value: bool) -> Self {
646 self.beta = value;
647 self
648 }
649
650 pub fn stream(mut self, value: bool) -> Self {
651 self.stream = value;
652 self
653 }
654
655 pub fn stream_options(mut self, value: StreamOptions) -> Self {
656 self.stream_options = Some(value);
657 self
658 }
659
660 pub fn response_format(mut self, value: ResponseType) -> Self {
661 self.response_format = Some(ResponseFormat { resp_type: value });
662 self
663 }
664
665 pub fn stop(mut self, value: Stop) -> Self {
666 self.stop = Some(value);
667 self
668 }
669
670 pub fn tools(mut self, value: Vec<ToolObject>) -> Self {
671 self.tools = Some(value);
672 self
673 }
674
675 pub fn tool_choice(mut self, value: ToolChoice) -> Self {
676 self.tool_choice = Some(value);
677 self
678 }
679
680 pub fn prompt(mut self, value: String) -> Self {
681 self.prompt = value;
682 self
683 }
684
685 pub fn temperature(mut self, value: u32) -> Result<Self> {
686 self.temperature = Some(Temperature::new(value)?);
687 Ok(self)
688 }
689
690 pub fn top_p(mut self, value: f32) -> Result<Self> {
691 self.top_p = Some(TopP::new(value)?);
692 Ok(self)
693 }
694
695 pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
696 self.presence_penalty = Some(PresencePenalty::new(value)?);
697 Ok(self)
698 }
699
700 pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
701 self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
702 Ok(self)
703 }
704
705 pub fn logprobs(mut self, value: bool) -> Self {
706 self.logprobs = Some(value);
707 self
708 }
709
710 pub fn top_logprobs(mut self, value: u32) -> Result<Self> {
711 self.top_logprobs = Some(TopLogprobs::new(value)?);
712 Ok(self)
713 }
714}
715
716impl RequestBuilder for CompletionsRequestBuilder {
717 type Request = CompletionsRequest;
718 type Response = ChatCompletion;
719 type Item = ChatCompletionStream<JSONChoiceStream>;
720
721 fn is_beta(&self) -> bool {
722 self.beta
723 }
724
725 fn is_stream(&self) -> bool {
726 self.stream
727 }
728
729 fn build(self) -> CompletionsRequest {
730 CompletionsRequest {
731 messages: self.messages,
732 model: self.model,
733 max_tokens: self.max_tokens,
734 response_format: self.response_format,
735 stop: self.stop,
736 stream: self.stream,
737 stream_options: self.stream_options,
738 tools: self.tools,
739 tool_choice: self.tool_choice,
740 prompt: self.prompt,
741 temperature: self.temperature,
742 top_p: self.top_p,
743 presence_penalty: self.presence_penalty,
744 frequency_penalty: self.frequency_penalty,
745 logprobs: self.logprobs,
746 top_logprobs: self.top_logprobs,
747 }
748 }
749}
750
751#[derive(Debug, Default, Clone, PartialEq, Serialize)]
753pub struct FMICompletionsRequest {
754 pub model: ModelType,
755 pub prompt: String,
756 pub echo: bool,
757
758 #[serde(skip_serializing_if = "Option::is_none")]
759 pub frequency_penalty: Option<FrequencyPenalty>,
760 #[serde(skip_serializing_if = "Option::is_none")]
761 pub logprobs: Option<bool>,
762 #[serde(skip_serializing_if = "Option::is_none")]
763 pub max_tokens: Option<MaxToken>,
764 #[serde(skip_serializing_if = "Option::is_none")]
765 pub presence_penalty: Option<PresencePenalty>,
766 #[serde(skip_serializing_if = "Option::is_none")]
767 pub stop: Option<Stop>,
768 pub stream: bool,
769 #[serde(skip_serializing_if = "Option::is_none")]
770 pub stream_options: Option<StreamOptions>,
771 pub suffix: String,
772 #[serde(skip_serializing_if = "Option::is_none")]
773 pub temperature: Option<Temperature>,
774 #[serde(skip_serializing_if = "Option::is_none")]
775 pub top_p: Option<TopP>,
776}
777
778#[derive(Debug, Default)]
779pub struct FMICompletionsRequestBuilder {
780 model: ModelType,
781 prompt: String,
782 echo: bool,
783 frequency_penalty: Option<FrequencyPenalty>,
784 logprobs: Option<bool>,
785 max_tokens: Option<MaxToken>,
786 presence_penalty: Option<PresencePenalty>,
787 stop: Option<Stop>,
788 stream: bool,
789 stream_options: Option<StreamOptions>,
790 suffix: String,
791 temperature: Option<Temperature>,
792 top_p: Option<TopP>,
793}
794
795impl FMICompletionsRequestBuilder {
796 pub fn new(prompt: &str, suffix: &str) -> Self {
797 Self {
798 model: ModelType::DeepSeekChat,
799 prompt: prompt.to_string(),
800 suffix: suffix.to_string(),
801 echo: false,
802 stream: false,
803 ..Default::default()
804 }
805 }
806
807 pub fn use_model(mut self, model: ModelType) -> Self {
808 self.model = model;
809 self
810 }
811
812 pub fn echo(mut self, value: bool) -> Self {
813 self.echo = value;
814 self
815 }
816
817 pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
818 self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
819 Ok(self)
820 }
821
822 pub fn logprobs(mut self, value: bool) -> Self {
823 self.logprobs = Some(value);
824 self
825 }
826
827 pub fn max_tokens(mut self, value: u32) -> Result<Self> {
828 self.max_tokens = Some(MaxToken::new(value)?);
829 Ok(self)
830 }
831
832 pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
833 self.presence_penalty = Some(PresencePenalty::new(value)?);
834 Ok(self)
835 }
836
837 pub fn stop(mut self, value: Stop) -> Self {
838 self.stop = Some(value);
839 self
840 }
841
842 pub fn stream(mut self, value: bool) -> Self {
843 self.stream = value;
844 self
845 }
846
847 pub fn stream_options(mut self, value: StreamOptions) -> Self {
848 self.stream_options = Some(value);
849 self
850 }
851
852 pub fn temperature(mut self, value: u32) -> Result<Self> {
853 self.temperature = Some(Temperature::new(value)?);
854 Ok(self)
855 }
856
857 pub fn top_p(mut self, value: f32) -> Result<Self> {
858 self.top_p = Some(TopP::new(value)?);
859 Ok(self)
860 }
861}
862
863impl RequestBuilder for FMICompletionsRequestBuilder {
864 type Request = FMICompletionsRequest;
865 type Response = ChatCompletion;
866 type Item = ChatCompletionStream<TextChoiceStream>;
867
868 fn is_beta(&self) -> bool {
869 true
870 }
871
872 fn is_stream(&self) -> bool {
873 self.stream
874 }
875
876 fn build(self) -> FMICompletionsRequest {
877 FMICompletionsRequest {
878 model: self.model,
879 prompt: self.prompt,
880 echo: self.echo,
881 frequency_penalty: self.frequency_penalty,
882 logprobs: self.logprobs,
883 max_tokens: self.max_tokens,
884 presence_penalty: self.presence_penalty,
885 stop: self.stop,
886 stream: self.stream,
887 stream_options: self.stream_options,
888 suffix: self.suffix,
889 temperature: self.temperature,
890 top_p: self.top_p,
891 }
892 }
893}