hehe_llm/types/
request.rs1use hehe_core::{Message, Metadata, ToolDefinition};
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Debug, Serialize, Deserialize)]
5pub struct CompletionRequest {
6 pub model: String,
7 pub messages: Vec<Message>,
8 #[serde(skip_serializing_if = "Option::is_none")]
9 pub system: Option<String>,
10 #[serde(skip_serializing_if = "Option::is_none")]
11 pub tools: Option<Vec<ToolDefinition>>,
12 #[serde(skip_serializing_if = "Option::is_none")]
13 pub tool_choice: Option<ToolChoice>,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub max_tokens: Option<u32>,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub temperature: Option<f32>,
18 #[serde(skip_serializing_if = "Option::is_none")]
19 pub top_p: Option<f32>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub stop: Option<Vec<String>>,
22 #[serde(default)]
23 pub stream: bool,
24 #[serde(default, skip_serializing_if = "Metadata::is_empty")]
25 pub metadata: Metadata,
26}
27
28impl CompletionRequest {
29 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
30 Self {
31 model: model.into(),
32 messages,
33 system: None,
34 tools: None,
35 tool_choice: None,
36 max_tokens: None,
37 temperature: None,
38 top_p: None,
39 stop: None,
40 stream: false,
41 metadata: Metadata::new(),
42 }
43 }
44
45 pub fn with_system(mut self, system: impl Into<String>) -> Self {
46 self.system = Some(system.into());
47 self
48 }
49
50 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
51 self.tools = Some(tools);
52 self
53 }
54
55 pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
56 self.tool_choice = Some(choice);
57 self
58 }
59
60 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
61 self.max_tokens = Some(max_tokens);
62 self
63 }
64
65 pub fn with_temperature(mut self, temperature: f32) -> Self {
66 self.temperature = Some(temperature);
67 self
68 }
69
70 pub fn with_top_p(mut self, top_p: f32) -> Self {
71 self.top_p = Some(top_p);
72 self
73 }
74
75 pub fn with_stop(mut self, stop: Vec<String>) -> Self {
76 self.stop = Some(stop);
77 self
78 }
79
80 pub fn streaming(mut self) -> Self {
81 self.stream = true;
82 self
83 }
84}
85
86#[derive(Clone, Debug, Serialize, Deserialize)]
87#[serde(rename_all = "snake_case")]
88pub enum ToolChoice {
89 Auto,
90 None,
91 Required,
92 Tool { name: String },
93}
94
95impl Default for ToolChoice {
96 fn default() -> Self {
97 Self::Auto
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_completion_request_builder() {
        let req = CompletionRequest::new("gpt-4", vec![Message::user("Hello")])
            .with_system("You are helpful")
            .with_max_tokens(1000)
            .with_temperature(0.7)
            .with_top_p(0.9)
            .with_stop(vec!["\n".to_string()])
            .streaming();

        assert_eq!(req.model, "gpt-4");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.system.as_deref(), Some("You are helpful"));
        assert_eq!(req.max_tokens, Some(1000));
        // Exact comparison is fine: the stored value is the same f32 literal.
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.top_p, Some(0.9));
        assert_eq!(req.stop, Some(vec!["\n".to_string()]));
        assert!(req.tools.is_none());
        assert!(req.tool_choice.is_none());
        assert!(req.stream);
    }

    #[test]
    fn test_new_request_starts_unconfigured() {
        let req = CompletionRequest::new("gpt-4", Vec::new());

        assert!(req.system.is_none());
        assert!(req.tools.is_none());
        assert!(req.tool_choice.is_none());
        assert!(req.max_tokens.is_none());
        assert!(req.temperature.is_none());
        assert!(req.top_p.is_none());
        assert!(req.stop.is_none());
        assert!(!req.stream);
    }

    #[test]
    fn test_unset_options_are_not_serialized() {
        let value = serde_json::to_value(CompletionRequest::new("gpt-4", Vec::new())).unwrap();

        assert_eq!(value["model"], "gpt-4");
        // `stream` has no skip_serializing_if, so it is always present.
        assert_eq!(value["stream"], false);
        for key in [
            "system",
            "tools",
            "tool_choice",
            "max_tokens",
            "temperature",
            "top_p",
            "stop",
        ] {
            assert!(value.get(key).is_none(), "unexpected key {}", key);
        }
    }

    #[test]
    fn test_deserialize_applies_defaults() {
        let req: CompletionRequest =
            serde_json::from_str(r#"{"model":"gpt-4","messages":[]}"#).unwrap();

        assert_eq!(req.model, "gpt-4");
        assert!(req.messages.is_empty());
        assert!(req.system.is_none());
        assert!(req.tool_choice.is_none());
        assert!(!req.stream);
    }

    #[test]
    fn test_tool_choice_serialization() {
        assert_eq!(serde_json::to_string(&ToolChoice::Auto).unwrap(), "\"auto\"");
        assert_eq!(serde_json::to_string(&ToolChoice::None).unwrap(), "\"none\"");
        assert_eq!(
            serde_json::to_string(&ToolChoice::Required).unwrap(),
            "\"required\""
        );

        // Pin the exact externally tagged wire shape, not just the presence of the name.
        let tool = serde_json::to_string(&ToolChoice::Tool {
            name: "search".into(),
        })
        .unwrap();
        assert_eq!(tool, r#"{"tool":{"name":"search"}}"#);
    }

    #[test]
    fn test_tool_choice_deserialization_and_default() {
        let parsed: ToolChoice = serde_json::from_str(r#"{"tool":{"name":"search"}}"#).unwrap();
        assert!(matches!(&parsed, ToolChoice::Tool { name } if name == "search"));

        let required: ToolChoice = serde_json::from_str("\"required\"").unwrap();
        assert!(matches!(required, ToolChoice::Required));

        assert!(matches!(ToolChoice::default(), ToolChoice::Auto));
    }
}