1use crate::response::{
2 AssistantMessage, ChatCompletion, ChatCompletionStream, JSONChoiceStream, ModelType,
3 TextChoiceStream,
4};
5use anyhow::{anyhow, Ok, Result};
6use schemars::schema::SchemaObject;
7use serde::{de::DeserializeOwned, ser::SerializeStruct, Deserialize, Serialize, Serializer};
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11pub struct FrequencyPenalty(pub f32);
12
13impl FrequencyPenalty {
14 pub fn new(v: f32) -> Result<Self> {
24 if !(-2.0..=2.0).contains(&v) {
25 return Err(anyhow!(
26 "Frequency penalty value must be between -2 and 2.".to_string()
27 ));
28 }
29 Ok(FrequencyPenalty(v))
30 }
31}
32
33impl Default for FrequencyPenalty {
34 fn default() -> Self {
36 FrequencyPenalty(0.0)
37 }
38}
39
40#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
42pub struct PresencePenalty(pub f32);
43
44impl PresencePenalty {
45 pub fn new(v: f32) -> Result<Self> {
55 if !(-2.0..=2.0).contains(&v) {
56 return Err(anyhow!(
57 "Presence penalty value must be between -2 and 2.".to_string()
58 ));
59 }
60 Ok(PresencePenalty(v))
61 }
62}
63
64impl Default for PresencePenalty {
65 fn default() -> Self {
67 PresencePenalty(0.0)
68 }
69}
70
71#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
73pub enum ResponseType {
74 #[serde(rename = "json_object")]
75 Json,
76 #[serde(rename = "text")]
77 Text,
78}
79
80#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
82pub struct ResponseFormat {
83 #[serde(rename = "type")]
84 pub resp_type: ResponseType,
85}
86
87impl ResponseFormat {
88 pub fn new(rt: ResponseType) -> Self {
94 ResponseFormat { resp_type: rt }
95 }
96}
97
98#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
100pub struct MaxToken(pub u32);
101
102impl MaxToken {
103 pub fn new(v: u32) -> Result<Self> {
113 if !(1..=8192).contains(&v) {
114 return Err(anyhow!("Max token must be between 1 and 8192.".to_string()));
115 }
116 Ok(MaxToken(v))
117 }
118}
119
120impl Default for MaxToken {
121 fn default() -> Self {
123 MaxToken(4096)
124 }
125}
126
127#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
129pub enum Stop {
130 Single(String),
131 Multiple(Vec<String>),
132}
133
134#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
136pub struct StreamOptions {
137 pub include_usage: bool,
138}
139
140impl StreamOptions {
141 pub fn new(include_usage: bool) -> Self {
147 StreamOptions { include_usage }
148 }
149}
150
151#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
153pub struct Temperature(pub u32);
154
155impl Temperature {
156 pub fn new(v: u32) -> Result<Self> {
166 if v > 2 {
167 return Err(anyhow!("Temperature must be between 0 and 2.".to_string()));
168 }
169 Ok(Temperature(v))
170 }
171}
172
173impl Default for Temperature {
174 fn default() -> Self {
176 Temperature(1)
177 }
178}
179
180#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
182pub struct TopP(pub f32);
183
184impl TopP {
185 pub fn new(v: f32) -> Result<Self> {
195 if !(0.0..=1.0).contains(&v) {
196 return Err(anyhow!("TopP value must be between 0and 2.".to_string()));
197 }
198 Ok(TopP(v))
199 }
200}
201
202impl Default for TopP {
203 fn default() -> Self {
205 TopP(1.0)
206 }
207}
208
209#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
211pub enum ToolType {
212 #[serde(rename = "function")]
213 Function,
214}
215
216#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
218pub struct Function {
219 pub description: String,
220 pub name: String,
221 pub parameters: SchemaObject,
222}
223
224#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
226pub struct ToolObject {
227 #[serde(rename = "type")]
228 pub tool_type: ToolType,
229 pub function: Function,
230}
231
232#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
234pub enum ChatCompletionToolChoice {
235 #[serde(rename = "none")]
236 None,
237 #[serde(rename = "auto")]
238 Auto,
239 #[serde(rename = "required")]
240 Required,
241}
242
243#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
245pub struct FunctionChoice {
246 pub name: String,
247}
248
249#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
251pub struct ChatCompletionNamedToolChoice {
252 #[serde(rename = "type")]
253 pub tool_type: ToolType,
254 pub function: FunctionChoice,
255}
256
257#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
259pub enum ToolChoice {
260 ChatCompletion(ChatCompletionToolChoice),
261 ChatCompletionNamed(ChatCompletionNamedToolChoice),
262}
263
264#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
266pub struct TopLogprobs(pub u32);
267
268impl TopLogprobs {
269 pub fn new(v: u32) -> Result<Self> {
279 if v > 20 {
280 return Err(anyhow!(
281 "Top log probs must be between 0 and 20.".to_string()
282 ));
283 }
284 Ok(TopLogprobs(v))
285 }
286}
287
288impl Default for TopLogprobs {
289 fn default() -> Self {
291 TopLogprobs(0)
292 }
293}
294
295#[derive(Debug, Clone, Serialize, Deserialize)]
297#[serde(tag = "role")]
298pub enum MessageRequest {
299 #[serde(rename = "system")]
300 System(SystemMessageRequest),
301 #[serde(rename = "user")]
302 User(UserMessageRequest),
303 #[serde(rename = "assistant")]
304 Assistant(AssistantMessage),
305 #[serde(rename = "tool")]
306 Tool(ToolMessageRequest),
307}
308
309impl MessageRequest {
310 pub fn get_content(&self) -> &str {
311 match self {
312 MessageRequest::System(req) => req.content.as_str(),
313 MessageRequest::User(req) => req.content.as_str(),
314 MessageRequest::Assistant(req) => req.content.as_str(),
315 MessageRequest::Tool(req) => req.content.as_str(),
316 }
317 }
318}
319
320#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
322pub struct SystemMessageRequest {
323 pub content: String,
324 pub name: Option<String>,
325}
326
327impl SystemMessageRequest {
328 pub fn new(msg: &str) -> Self {
334 SystemMessageRequest {
335 content: msg.to_string(),
336 name: None,
337 }
338 }
339
340 pub fn new_with_name(name: &str, msg: &str) -> Self {
347 SystemMessageRequest {
348 content: msg.to_string(),
349 name: Some(name.to_string()),
350 }
351 }
352}
353
354#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
356pub struct UserMessageRequest {
357 pub content: String,
358 pub name: Option<String>,
359}
360
361impl UserMessageRequest {
362 pub fn new(msg: &str) -> Self {
368 UserMessageRequest {
369 content: msg.to_string(),
370 name: None,
371 }
372 }
373
374 pub fn new_with_name(name: &str, msg: &str) -> Self {
381 UserMessageRequest {
382 content: msg.to_string(),
383 name: Some(name.to_string()),
384 }
385 }
386}
387
388#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
390pub struct ToolMessageRequest {
391 pub content: String,
392 pub tool_call_id: String,
393}
394
395impl ToolMessageRequest {
396 pub fn new(msg: &str, tool_call_id: &str) -> Self {
403 ToolMessageRequest {
404 content: msg.to_string(),
405 tool_call_id: tool_call_id.to_string(),
406 }
407 }
408}
409
410pub trait RequestBuilder {
411 type Request: Serialize;
412 type Response: DeserializeOwned;
413 type Item: DeserializeOwned + Send + 'static;
414
415 fn is_beta(&self) -> bool;
416 fn is_stream(&self) -> bool;
417 fn build(self) -> Self::Request;
418}
419
420#[derive(Debug, Default, Clone, Deserialize)]
422pub struct CompletionsRequest {
423 pub messages: Vec<MessageRequest>,
424 pub model: ModelType,
425 pub prompt: String,
426 #[serde(skip_serializing_if = "Option::is_none")]
427 pub max_tokens: Option<MaxToken>,
428 #[serde(skip_serializing_if = "Option::is_none")]
429 pub response_format: Option<ResponseFormat>,
430 #[serde(skip_serializing_if = "Option::is_none")]
431 pub stop: Option<Stop>,
432 pub stream: bool,
433 #[serde(skip_serializing_if = "Option::is_none")]
434 pub stream_options: Option<StreamOptions>,
435 #[serde(skip_serializing_if = "Option::is_none")]
436 pub tools: Option<Vec<ToolObject>>,
437 #[serde(skip_serializing_if = "Option::is_none")]
438 pub tool_choice: Option<ToolChoice>,
439
440 #[serde(skip_serializing_if = "Option::is_none")]
442 pub temperature: Option<Temperature>,
443 #[serde(skip_serializing_if = "Option::is_none")]
444 pub top_p: Option<TopP>,
445 #[serde(skip_serializing_if = "Option::is_none")]
446 pub presence_penalty: Option<PresencePenalty>,
447 #[serde(skip_serializing_if = "Option::is_none")]
448 pub frequency_penalty: Option<FrequencyPenalty>,
449 #[serde(skip_serializing_if = "Option::is_none")]
450 pub logprobs: Option<bool>,
451 #[serde(skip_serializing_if = "Option::is_none")]
452 pub top_logprobs: Option<TopLogprobs>,
453}
454
455impl Serialize for CompletionsRequest {
456 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
457 where
458 S: Serializer,
459 {
460 let mut state = serializer.serialize_struct("CompletionsRequest", 12)?;
461
462 state.serialize_field("messages", &self.messages)?;
463 state.serialize_field("model", &self.model)?;
464 state.serialize_field("max_tokens", &self.max_tokens)?;
465 state.serialize_field("response_format", &self.response_format)?;
466 state.serialize_field("stop", &self.stop)?;
467 state.serialize_field("stream", &self.stream)?;
468 state.serialize_field("stream_options", &self.stream_options)?;
469 state.serialize_field("tools", &self.tools)?;
470 state.serialize_field("tool_choice", &self.tool_choice)?;
471 state.serialize_field("prompt", &self.prompt)?;
472
473 if self.model != ModelType::DeepSeekReasoner {
475 state.serialize_field("temperature", &self.temperature)?;
476 state.serialize_field("top_p", &self.top_p)?;
477 state.serialize_field("presence_penalty", &self.presence_penalty)?;
478 state.serialize_field("frequency_penalty", &self.frequency_penalty)?;
479 state.serialize_field("logprobs", &self.logprobs)?;
480 state.serialize_field("top_logprobs", &self.top_logprobs)?;
481 }
482
483 state.end()
484 }
485}
486
487#[derive(Debug, Default)]
488pub struct CompletionsRequestBuilder {
489 beta: bool,
491 messages: Vec<MessageRequest>,
492 model: ModelType,
493
494 stream: bool,
495 stream_options: Option<StreamOptions>,
496
497 max_tokens: Option<MaxToken>,
498 response_format: Option<ResponseFormat>,
499 stop: Option<Stop>,
500 tools: Option<Vec<ToolObject>>,
501 tool_choice: Option<ToolChoice>,
502 prompt: String,
503 temperature: Option<Temperature>,
504 top_p: Option<TopP>,
505 presence_penalty: Option<PresencePenalty>,
506 frequency_penalty: Option<FrequencyPenalty>,
507 logprobs: Option<bool>,
508 top_logprobs: Option<TopLogprobs>,
509}
510
511impl CompletionsRequestBuilder {
512 pub fn new(messages: Vec<MessageRequest>) -> Self {
513 Self {
514 messages,
515 model: ModelType::DeepSeekChat,
516 prompt: String::new(),
517 ..Default::default()
518 }
519 }
520 pub fn use_model(mut self, model: ModelType) -> Self {
521 self.model = model;
522 self
523 }
524
525 pub fn append_fim_message(self, _prompt: &str, _suffix: &str) -> Self {
527 todo!("Not enough detail in document")
528 }
529
530 pub fn append_prefix_message(mut self, msg: &str) -> Self {
532 self.messages.push(MessageRequest::Assistant(
533 AssistantMessage::new(msg).set_prefix(msg),
534 ));
535 self
536 }
537
538 pub fn append_user_message(mut self, msg: &str) -> Self {
539 self.messages
540 .push(MessageRequest::User(UserMessageRequest::new(msg)));
541 self
542 }
543
544 pub fn max_tokens(mut self, value: u32) -> Result<Self> {
545 self.max_tokens = Some(MaxToken::new(value)?);
546 Ok(self)
547 }
548
549 pub fn use_beta(mut self, value: bool) -> Self {
550 self.beta = value;
551 self
552 }
553
554 pub fn stream(mut self, value: bool) -> Self {
555 self.stream = value;
556 self
557 }
558
559 pub fn stream_options(mut self, value: StreamOptions) -> Self {
560 self.stream_options = Some(value);
561 self
562 }
563
564 pub fn response_format(mut self, value: ResponseType) -> Self {
565 self.response_format = Some(ResponseFormat { resp_type: value });
566 self
567 }
568
569 pub fn stop(mut self, value: Stop) -> Self {
570 self.stop = Some(value);
571 self
572 }
573
574 pub fn tools(mut self, value: Vec<ToolObject>) -> Self {
575 self.tools = Some(value);
576 self
577 }
578
579 pub fn tool_choice(mut self, value: ToolChoice) -> Self {
580 self.tool_choice = Some(value);
581 self
582 }
583
584 pub fn prompt(mut self, value: String) -> Self {
585 self.prompt = value;
586 self
587 }
588
589 pub fn temperature(mut self, value: u32) -> Result<Self> {
590 self.temperature = Some(Temperature::new(value)?);
591 Ok(self)
592 }
593
594 pub fn top_p(mut self, value: f32) -> Result<Self> {
595 self.top_p = Some(TopP::new(value)?);
596 Ok(self)
597 }
598
599 pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
600 self.presence_penalty = Some(PresencePenalty::new(value)?);
601 Ok(self)
602 }
603
604 pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
605 self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
606 Ok(self)
607 }
608
609 pub fn logprobs(mut self, value: bool) -> Self {
610 self.logprobs = Some(value);
611 self
612 }
613
614 pub fn top_logprobs(mut self, value: u32) -> Result<Self> {
615 self.top_logprobs = Some(TopLogprobs::new(value)?);
616 Ok(self)
617 }
618}
619
620impl RequestBuilder for CompletionsRequestBuilder {
621 type Request = CompletionsRequest;
622 type Response = ChatCompletion;
623 type Item = ChatCompletionStream<JSONChoiceStream>;
624
625 fn is_beta(&self) -> bool {
626 self.beta
627 }
628
629 fn is_stream(&self) -> bool {
630 self.stream
631 }
632
633 fn build(self) -> CompletionsRequest {
634 CompletionsRequest {
635 messages: self.messages,
636 model: self.model,
637 max_tokens: self.max_tokens,
638 response_format: self.response_format,
639 stop: self.stop,
640 stream: self.stream,
641 stream_options: self.stream_options,
642 tools: self.tools,
643 tool_choice: self.tool_choice,
644 prompt: self.prompt,
645 temperature: self.temperature,
646 top_p: self.top_p,
647 presence_penalty: self.presence_penalty,
648 frequency_penalty: self.frequency_penalty,
649 logprobs: self.logprobs,
650 top_logprobs: self.top_logprobs,
651 }
652 }
653}
654
655#[derive(Debug, Default, Clone, PartialEq, Serialize)]
657pub struct FMICompletionsRequest {
658 pub model: ModelType,
659 pub prompt: String,
660 pub echo: bool,
661
662 #[serde(skip_serializing_if = "Option::is_none")]
663 pub frequency_penalty: Option<FrequencyPenalty>,
664 #[serde(skip_serializing_if = "Option::is_none")]
665 pub logprobs: Option<bool>,
666 #[serde(skip_serializing_if = "Option::is_none")]
667 pub max_tokens: Option<MaxToken>,
668 #[serde(skip_serializing_if = "Option::is_none")]
669 pub presence_penalty: Option<PresencePenalty>,
670 #[serde(skip_serializing_if = "Option::is_none")]
671 pub stop: Option<Stop>,
672 pub stream: bool,
673 #[serde(skip_serializing_if = "Option::is_none")]
674 pub stream_options: Option<StreamOptions>,
675 pub suffix: String,
676 #[serde(skip_serializing_if = "Option::is_none")]
677 pub temperature: Option<Temperature>,
678 #[serde(skip_serializing_if = "Option::is_none")]
679 pub top_p: Option<TopP>,
680}
681
682#[derive(Debug, Default)]
683pub struct FMICompletionsRequestBuilder {
684 model: ModelType,
685 prompt: String,
686 echo: bool,
687 frequency_penalty: Option<FrequencyPenalty>,
688 logprobs: Option<bool>,
689 max_tokens: Option<MaxToken>,
690 presence_penalty: Option<PresencePenalty>,
691 stop: Option<Stop>,
692 stream: bool,
693 stream_options: Option<StreamOptions>,
694 suffix: String,
695 temperature: Option<Temperature>,
696 top_p: Option<TopP>,
697}
698
699impl FMICompletionsRequestBuilder {
700 pub fn new(prompt: &str, suffix: &str) -> Self {
701 Self {
702 model: ModelType::DeepSeekChat,
704 prompt: prompt.to_string(),
705 suffix: suffix.to_string(),
706 echo: false,
707 stream: false,
708 ..Default::default()
709 }
710 }
711
712 pub fn echo(mut self, value: bool) -> Self {
713 self.echo = value;
714 self
715 }
716
717 pub fn frequency_penalty(mut self, value: f32) -> Result<Self> {
718 self.frequency_penalty = Some(FrequencyPenalty::new(value)?);
719 Ok(self)
720 }
721
722 pub fn logprobs(mut self, value: bool) -> Self {
723 self.logprobs = Some(value);
724 self
725 }
726
727 pub fn max_tokens(mut self, value: u32) -> Result<Self> {
728 self.max_tokens = Some(MaxToken::new(value)?);
729 Ok(self)
730 }
731
732 pub fn presence_penalty(mut self, value: f32) -> Result<Self> {
733 self.presence_penalty = Some(PresencePenalty::new(value)?);
734 Ok(self)
735 }
736
737 pub fn stop(mut self, value: Stop) -> Self {
738 self.stop = Some(value);
739 self
740 }
741
742 pub fn stream(mut self, value: bool) -> Self {
743 self.stream = value;
744 self
745 }
746
747 pub fn stream_options(mut self, value: StreamOptions) -> Self {
748 self.stream_options = Some(value);
749 self
750 }
751
752 pub fn temperature(mut self, value: u32) -> Result<Self> {
753 self.temperature = Some(Temperature::new(value)?);
754 Ok(self)
755 }
756
757 pub fn top_p(mut self, value: f32) -> Result<Self> {
758 self.top_p = Some(TopP::new(value)?);
759 Ok(self)
760 }
761}
762
763impl RequestBuilder for FMICompletionsRequestBuilder {
764 type Request = FMICompletionsRequest;
765 type Response = ChatCompletion;
766 type Item = ChatCompletionStream<TextChoiceStream>;
767
768 fn is_beta(&self) -> bool {
769 true
770 }
771
772 fn is_stream(&self) -> bool {
773 self.stream
774 }
775
776 fn build(self) -> FMICompletionsRequest {
777 FMICompletionsRequest {
778 model: self.model,
779 prompt: self.prompt,
780 echo: self.echo,
781 frequency_penalty: self.frequency_penalty,
782 logprobs: self.logprobs,
783 max_tokens: self.max_tokens,
784 presence_penalty: self.presence_penalty,
785 stop: self.stop,
786 stream: self.stream,
787 stream_options: self.stream_options,
788 suffix: self.suffix,
789 temperature: self.temperature,
790 top_p: self.top_p,
791 }
792 }
793}