hehe_llm/types/
request.rs

1use hehe_core::{Message, Metadata, ToolDefinition};
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Debug, Serialize, Deserialize)]
5pub struct CompletionRequest {
6    pub model: String,
7    pub messages: Vec<Message>,
8    #[serde(skip_serializing_if = "Option::is_none")]
9    pub system: Option<String>,
10    #[serde(skip_serializing_if = "Option::is_none")]
11    pub tools: Option<Vec<ToolDefinition>>,
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub tool_choice: Option<ToolChoice>,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub max_tokens: Option<u32>,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub temperature: Option<f32>,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub top_p: Option<f32>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub stop: Option<Vec<String>>,
22    #[serde(default)]
23    pub stream: bool,
24    #[serde(default, skip_serializing_if = "Metadata::is_empty")]
25    pub metadata: Metadata,
26}
27
28impl CompletionRequest {
29    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
30        Self {
31            model: model.into(),
32            messages,
33            system: None,
34            tools: None,
35            tool_choice: None,
36            max_tokens: None,
37            temperature: None,
38            top_p: None,
39            stop: None,
40            stream: false,
41            metadata: Metadata::new(),
42        }
43    }
44
45    pub fn with_system(mut self, system: impl Into<String>) -> Self {
46        self.system = Some(system.into());
47        self
48    }
49
50    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
51        self.tools = Some(tools);
52        self
53    }
54
55    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
56        self.tool_choice = Some(choice);
57        self
58    }
59
60    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
61        self.max_tokens = Some(max_tokens);
62        self
63    }
64
65    pub fn with_temperature(mut self, temperature: f32) -> Self {
66        self.temperature = Some(temperature);
67        self
68    }
69
70    pub fn with_top_p(mut self, top_p: f32) -> Self {
71        self.top_p = Some(top_p);
72        self
73    }
74
75    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
76        self.stop = Some(stop);
77        self
78    }
79
80    pub fn streaming(mut self) -> Self {
81        self.stream = true;
82        self
83    }
84}
85
86#[derive(Clone, Debug, Serialize, Deserialize)]
87#[serde(rename_all = "snake_case")]
88pub enum ToolChoice {
89    Auto,
90    None,
91    Required,
92    Tool { name: String },
93}
94
95impl Default for ToolChoice {
96    fn default() -> Self {
97        Self::Auto
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method stores its argument in the matching field.
    #[test]
    fn test_completion_request_builder() {
        let req = CompletionRequest::new("gpt-4", vec![Message::user("Hello")])
            .with_system("You are helpful")
            .with_max_tokens(1000)
            .with_temperature(0.7)
            .with_top_p(0.9)
            .with_stop(vec!["END".to_string()])
            .with_tool_choice(ToolChoice::Required)
            .streaming();

        assert_eq!(req.model, "gpt-4");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.system.as_deref(), Some("You are helpful"));
        assert_eq!(req.max_tokens, Some(1000));
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.top_p, Some(0.9));
        assert_eq!(req.stop, Some(vec!["END".to_string()]));
        assert!(matches!(req.tool_choice, Some(ToolChoice::Required)));
        assert!(req.tools.is_none());
        assert!(req.stream);
    }

    /// `new` leaves every optional setting unset and streaming off.
    #[test]
    fn test_new_request_defaults() {
        let req = CompletionRequest::new("gpt-4", Vec::new());

        assert!(req.messages.is_empty());
        assert!(req.system.is_none());
        assert!(req.tools.is_none());
        assert!(req.tool_choice.is_none());
        assert!(req.max_tokens.is_none());
        assert!(req.temperature.is_none());
        assert!(req.top_p.is_none());
        assert!(req.stop.is_none());
        assert!(!req.stream);
    }

    /// `None` options are skipped on serialization; `stream` is always written.
    #[test]
    fn test_unset_options_are_not_serialized() {
        let req = CompletionRequest::new("gpt-4", vec![Message::user("Hello")]);
        let json = serde_json::to_value(&req).unwrap();
        let obj = json
            .as_object()
            .expect("a request serializes to a JSON object");

        for key in [
            "system",
            "tools",
            "tool_choice",
            "max_tokens",
            "temperature",
            "top_p",
            "stop",
        ] {
            assert!(!obj.contains_key(key), "unexpected key: {}", key);
        }
        assert_eq!(obj["model"], "gpt-4");
        assert_eq!(obj["stream"], false);
    }

    /// Pins the exact wire format so a change of enum representation is caught.
    #[test]
    fn test_tool_choice_serialization() {
        assert_eq!(serde_json::to_string(&ToolChoice::Auto).unwrap(), "\"auto\"");
        assert_eq!(serde_json::to_string(&ToolChoice::None).unwrap(), "\"none\"");
        assert_eq!(
            serde_json::to_string(&ToolChoice::Required).unwrap(),
            "\"required\""
        );

        let tool = serde_json::to_string(&ToolChoice::Tool {
            name: "search".into(),
        })
        .unwrap();
        assert_eq!(tool, r#"{"tool":{"name":"search"}}"#);
    }

    /// The serialized forms deserialize back to the same variants.
    #[test]
    fn test_tool_choice_deserialization() {
        let parsed: ToolChoice = serde_json::from_str(r#"{"tool":{"name":"search"}}"#).unwrap();
        assert!(matches!(parsed, ToolChoice::Tool { ref name } if name == "search"));

        let parsed: ToolChoice = serde_json::from_str("\"none\"").unwrap();
        assert!(matches!(parsed, ToolChoice::None));

        let parsed: ToolChoice = serde_json::from_str("\"required\"").unwrap();
        assert!(matches!(parsed, ToolChoice::Required));
    }

    #[test]
    fn test_tool_choice_default_is_auto() {
        assert!(matches!(ToolChoice::default(), ToolChoice::Auto));
    }
}