// hehe_llm/types/response.rs

use hehe_core::{event::TokenUsage, stream::StopReason, Message, Metadata};
use serde::{Deserialize, Serialize};

/// A complete (non-streaming) model response: the returned message plus
/// identifying and accounting information.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Identifier for this response.
    pub id: String,
    /// Identifier of the model that produced the response.
    pub model: String,
    /// The message returned by the model.
    pub message: Message,
    /// Why generation stopped, when known; omitted from serialization if absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<StopReason>,
    /// Token accounting for this response.
    pub usage: TokenUsage,
    /// Extra key/value data; defaults on deserialization and is skipped
    /// when empty on serialization.
    #[serde(default, skip_serializing_if = "Metadata::is_empty")]
    pub metadata: Metadata,
}
16impl CompletionResponse {
17    pub fn new(id: impl Into<String>, model: impl Into<String>, message: Message) -> Self {
18        Self {
19            id: id.into(),
20            model: model.into(),
21            message,
22            stop_reason: None,
23            usage: TokenUsage::default(),
24            metadata: Metadata::new(),
25        }
26    }
27
28    pub fn with_stop_reason(mut self, reason: StopReason) -> Self {
29        self.stop_reason = Some(reason);
30        self
31    }
32
33    pub fn with_usage(mut self, usage: TokenUsage) -> Self {
34        self.usage = usage;
35        self
36    }
37
38    pub fn text_content(&self) -> String {
39        self.message.text_content()
40    }
41
42    pub fn has_tool_use(&self) -> bool {
43        self.message.has_tool_use()
44    }
45}
/// Capability metadata describing a model offered by a provider.
///
/// All fields are plain data (`String`/`Option<u32>`/`bool`), so equality
/// and hashing derives come for free and let callers deduplicate or key
/// model lists.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ModelInfo {
    /// Model identifier.
    pub id: String,
    /// Display name; [`ModelInfo::new`] initializes it to `id`.
    pub name: String,
    /// Name of the provider serving this model.
    pub provider: String,
    /// Context window size in tokens, when known.
    pub context_window: Option<u32>,
    /// Maximum output tokens, when known.
    pub max_output_tokens: Option<u32>,
    /// Whether the model supports tool/function calling.
    pub supports_tools: bool,
    /// Whether the model accepts image input.
    pub supports_vision: bool,
    /// Whether the model supports streaming responses.
    pub supports_streaming: bool,
}
59impl ModelInfo {
60    pub fn new(id: impl Into<String>, provider: impl Into<String>) -> Self {
61        let id = id.into();
62        Self {
63            name: id.clone(),
64            id,
65            provider: provider.into(),
66            context_window: None,
67            max_output_tokens: None,
68            supports_tools: false,
69            supports_vision: false,
70            supports_streaming: true,
71        }
72    }
73
74    pub fn with_name(mut self, name: impl Into<String>) -> Self {
75        self.name = name.into();
76        self
77    }
78
79    pub fn with_context_window(mut self, size: u32) -> Self {
80        self.context_window = Some(size);
81        self
82    }
83
84    pub fn with_max_output_tokens(mut self, max: u32) -> Self {
85        self.max_output_tokens = Some(max);
86        self
87    }
88
89    pub fn with_tools(mut self) -> Self {
90        self.supports_tools = true;
91        self
92    }
93
94    pub fn with_vision(mut self) -> Self {
95        self.supports_vision = true;
96        self
97    }
98}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder chain on CompletionResponse, then the delegating accessors.
    #[test]
    fn test_completion_response() {
        let response = CompletionResponse::new("resp-123", "gpt-4", Message::assistant("Hello"))
            .with_usage(TokenUsage::new(10, 5))
            .with_stop_reason(StopReason::EndTurn);

        assert_eq!(response.usage.total(), 15);
        assert_eq!(response.id, "resp-123");
        assert_eq!(response.text_content(), "Hello");
    }

    /// Builder chain on ModelInfo flips exactly the requested capabilities.
    #[test]
    fn test_model_info() {
        let info = ModelInfo::new("gpt-4o", "openai")
            .with_name("GPT-4o")
            .with_context_window(128000)
            .with_tools()
            .with_vision();

        assert!(info.supports_tools);
        assert!(info.supports_vision);
        assert_eq!(info.id, "gpt-4o");
    }
}