// rexis_llm/response.rs

//! # RSLLM Response Types
//!
//! Response types for chat completions, embeddings, and other LLM operations.
//! Supports both streaming and non-streaming responses with usage tracking.
6use crate::{MessageRole, ToolCall};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Response from a chat completion request.
///
/// Returned by non-streaming chat calls; carries the generated text plus
/// optional provider-supplied details (usage, finish reason, tool calls).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Generated content
    pub content: String,

    /// Model used for generation
    pub model: String,

    /// Usage statistics, when reported by the provider
    pub usage: Option<Usage>,

    /// Finish reason as reported by the provider (e.g. "stop", "length")
    pub finish_reason: Option<String>,

    /// Tool calls made by the assistant, if any
    pub tool_calls: Option<Vec<ToolCall>>,

    /// Arbitrary response metadata attached by the provider or client
    pub metadata: HashMap<String, serde_json::Value>,

    /// Response timestamp (serialized as Unix seconds, or null when absent)
    #[serde(with = "chrono::serde::ts_seconds_option")]
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,

    /// Response ID (if provided by provider)
    pub id: Option<String>,
}
38
39impl ChatResponse {
40    /// Create a new chat response
41    pub fn new(content: impl Into<String>, model: impl Into<String>) -> Self {
42        Self {
43            content: content.into(),
44            model: model.into(),
45            usage: None,
46            finish_reason: None,
47            tool_calls: None,
48            metadata: HashMap::new(),
49            timestamp: Some(chrono::Utc::now()),
50            id: None,
51        }
52    }
53
54    /// Set usage statistics
55    pub fn with_usage(mut self, usage: Usage) -> Self {
56        self.usage = Some(usage);
57        self
58    }
59
60    /// Set finish reason
61    pub fn with_finish_reason(mut self, reason: impl Into<String>) -> Self {
62        self.finish_reason = Some(reason.into());
63        self
64    }
65
66    /// Set tool calls
67    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
68        self.tool_calls = Some(tool_calls);
69        self
70    }
71
72    /// Add metadata
73    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
74        self.metadata.insert(key.into(), value);
75        self
76    }
77
78    /// Set response ID
79    pub fn with_id(mut self, id: impl Into<String>) -> Self {
80        self.id = Some(id.into());
81        self
82    }
83
84    /// Check if the response contains tool calls
85    pub fn has_tool_calls(&self) -> bool {
86        self.tool_calls
87            .as_ref()
88            .map_or(false, |calls| !calls.is_empty())
89    }
90
91    /// Check if the response finished successfully
92    pub fn is_finished(&self) -> bool {
93        matches!(
94            self.finish_reason.as_deref(),
95            Some("stop") | Some("end_turn") | Some("tool_calls")
96        )
97    }
98
99    /// Check if the response was truncated due to length
100    pub fn is_truncated(&self) -> bool {
101        matches!(
102            self.finish_reason.as_deref(),
103            Some("length") | Some("max_tokens")
104        )
105    }
106
107    /// Get the content length
108    pub fn content_length(&self) -> usize {
109        self.content.len()
110    }
111}
112
/// Response from a completion request (non-chat, plain text-in/text-out).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated text
    pub text: String,

    /// Model used for generation
    pub model: String,

    /// Usage statistics, when reported by the provider
    pub usage: Option<Usage>,

    /// Finish reason as reported by the provider (e.g. "stop", "length")
    pub finish_reason: Option<String>,

    /// Log probabilities (present only if requested from the provider)
    pub logprobs: Option<LogProbs>,

    /// Arbitrary response metadata attached by the provider or client
    pub metadata: HashMap<String, serde_json::Value>,

    /// Response timestamp (serialized as Unix seconds, or null when absent)
    #[serde(with = "chrono::serde::ts_seconds_option")]
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,

    /// Response ID (if provided by provider)
    pub id: Option<String>,
}
141
142impl CompletionResponse {
143    /// Create a new completion response
144    pub fn new(text: impl Into<String>, model: impl Into<String>) -> Self {
145        Self {
146            text: text.into(),
147            model: model.into(),
148            usage: None,
149            finish_reason: None,
150            logprobs: None,
151            metadata: HashMap::new(),
152            timestamp: Some(chrono::Utc::now()),
153            id: None,
154        }
155    }
156
157    /// Set usage statistics
158    pub fn with_usage(mut self, usage: Usage) -> Self {
159        self.usage = Some(usage);
160        self
161    }
162
163    /// Set finish reason
164    pub fn with_finish_reason(mut self, reason: impl Into<String>) -> Self {
165        self.finish_reason = Some(reason.into());
166        self
167    }
168
169    /// Set log probabilities
170    pub fn with_logprobs(mut self, logprobs: LogProbs) -> Self {
171        self.logprobs = Some(logprobs);
172        self
173    }
174
175    /// Add metadata
176    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
177        self.metadata.insert(key.into(), value);
178        self
179    }
180
181    /// Set response ID
182    pub fn with_id(mut self, id: impl Into<String>) -> Self {
183        self.id = Some(id.into());
184        self
185    }
186}
187
/// A single chunk in a streaming response.
///
/// A stream is a sequence of chunks; usage statistics and the finish reason
/// typically arrive only on the final chunk (`is_done == true`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// Content delta for this chunk
    pub content: String,

    /// Whether this is a delta (partial) or complete content
    pub is_delta: bool,

    /// Whether this is the final chunk of the stream
    pub is_done: bool,

    /// Model name
    pub model: String,

    /// Role of the message (if applicable)
    pub role: Option<MessageRole>,

    /// Tool calls delta (if applicable)
    pub tool_calls_delta: Option<Vec<ToolCallDelta>>,

    /// Finish reason (if this is the final chunk)
    pub finish_reason: Option<String>,

    /// Usage statistics (typically only in the final chunk)
    pub usage: Option<Usage>,

    /// Arbitrary chunk metadata
    pub metadata: HashMap<String, serde_json::Value>,

    /// Chunk timestamp (serialized as Unix seconds, or null when absent)
    #[serde(with = "chrono::serde::ts_seconds_option")]
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
}
222
223impl StreamChunk {
224    /// Create a new stream chunk
225    pub fn new(
226        content: impl Into<String>,
227        model: impl Into<String>,
228        is_delta: bool,
229        is_done: bool,
230    ) -> Self {
231        Self {
232            content: content.into(),
233            is_delta,
234            is_done,
235            model: model.into(),
236            role: None,
237            tool_calls_delta: None,
238            finish_reason: None,
239            usage: None,
240            metadata: HashMap::new(),
241            timestamp: Some(chrono::Utc::now()),
242        }
243    }
244
245    /// Create a delta chunk
246    pub fn delta(content: impl Into<String>, model: impl Into<String>) -> Self {
247        Self::new(content, model, true, false)
248    }
249
250    /// Create a final chunk
251    pub fn done(model: impl Into<String>) -> Self {
252        Self::new("", model, false, true)
253    }
254
255    /// Set the role
256    pub fn with_role(mut self, role: MessageRole) -> Self {
257        self.role = Some(role);
258        self
259    }
260
261    /// Set tool calls delta
262    pub fn with_tool_calls_delta(mut self, delta: Vec<ToolCallDelta>) -> Self {
263        self.tool_calls_delta = Some(delta);
264        self
265    }
266
267    /// Set finish reason
268    pub fn with_finish_reason(mut self, reason: impl Into<String>) -> Self {
269        self.finish_reason = Some(reason.into());
270        self
271    }
272
273    /// Set usage statistics
274    pub fn with_usage(mut self, usage: Usage) -> Self {
275        self.usage = Some(usage);
276        self
277    }
278
279    /// Add metadata
280    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
281        self.metadata.insert(key.into(), value);
282        self
283    }
284
285    /// Check if this chunk has content
286    pub fn has_content(&self) -> bool {
287        !self.content.is_empty()
288    }
289
290    /// Check if this chunk has tool calls
291    pub fn has_tool_calls(&self) -> bool {
292        self.tool_calls_delta
293            .as_ref()
294            .map_or(false, |calls| !calls.is_empty())
295    }
296}
297
/// Tool call delta for streaming responses.
///
/// Deltas with the same `index` belong to the same in-progress tool call;
/// `id` and `call_type` appear only on the delta that starts a new call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    /// Tool call index within the response
    pub index: u32,

    /// Tool call ID (present only when starting a new call)
    pub id: Option<String>,

    /// Tool call type (present only when starting a new call);
    /// serialized as "type", which is a reserved word in Rust
    #[serde(rename = "type")]
    pub call_type: Option<String>,

    /// Function delta for this chunk
    pub function: Option<ToolFunctionDelta>,
}
314
/// Tool function delta for streaming responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunctionDelta {
    /// Function name (present only when starting a new call)
    pub name: Option<String>,

    /// Arguments delta: a fragment of the JSON arguments string, to be
    /// concatenated with fragments from subsequent deltas
    pub arguments: Option<String>,
}
324
/// Usage statistics for API calls.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Number of tokens in the prompt
    pub prompt_tokens: u32,

    /// Number of tokens in the completion
    pub completion_tokens: u32,

    /// Total number of tokens used (prompt + completion when built via
    /// [`Usage::new`]; providers may report it directly on deserialization)
    pub total_tokens: u32,

    /// Number of cached prompt tokens (if the provider reports caching)
    pub cached_tokens: Option<u32>,

    /// Reasoning tokens (for models with reasoning capabilities)
    pub reasoning_tokens: Option<u32>,
}
343
344impl Usage {
345    /// Create new usage statistics
346    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
347        Self {
348            prompt_tokens,
349            completion_tokens,
350            total_tokens: prompt_tokens + completion_tokens,
351            cached_tokens: None,
352            reasoning_tokens: None,
353        }
354    }
355
356    /// Set cached tokens
357    pub fn with_cached_tokens(mut self, cached_tokens: u32) -> Self {
358        self.cached_tokens = Some(cached_tokens);
359        self
360    }
361
362    /// Set reasoning tokens
363    pub fn with_reasoning_tokens(mut self, reasoning_tokens: u32) -> Self {
364        self.reasoning_tokens = Some(reasoning_tokens);
365        self
366    }
367
368    /// Get effective prompt tokens (excluding cached)
369    pub fn effective_prompt_tokens(&self) -> u32 {
370        self.prompt_tokens - self.cached_tokens.unwrap_or(0)
371    }
372
373    /// Get total cost in tokens
374    pub fn total_cost(&self) -> u32 {
375        self.effective_prompt_tokens() + self.completion_tokens
376    }
377}
378
/// Log probabilities for completion responses.
///
/// NOTE(review): the three vectors appear to be parallel, indexed by token
/// position — confirm against the provider's logprobs format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogProbs {
    /// Log probability of each generated token (`None` where unavailable)
    pub token_logprobs: Vec<Option<f64>>,

    /// Top alternative tokens and their log probabilities, per position
    pub top_logprobs: Vec<Option<HashMap<String, f64>>>,

    /// Text offset for each token
    pub text_offset: Vec<usize>,
}
391
/// Response from an embedding request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Embedding vectors, one per input item
    pub embeddings: Vec<Vec<f32>>,

    /// Model used for embeddings
    pub model: String,

    /// Usage statistics, when reported by the provider
    pub usage: Option<Usage>,

    /// Arbitrary response metadata attached by the provider or client
    pub metadata: HashMap<String, serde_json::Value>,

    /// Response timestamp (serialized as Unix seconds, or null when absent)
    #[serde(with = "chrono::serde::ts_seconds_option")]
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
}
411
412impl EmbeddingResponse {
413    /// Create a new embedding response
414    pub fn new(embeddings: Vec<Vec<f32>>, model: impl Into<String>) -> Self {
415        Self {
416            embeddings,
417            model: model.into(),
418            usage: None,
419            metadata: HashMap::new(),
420            timestamp: Some(chrono::Utc::now()),
421        }
422    }
423
424    /// Set usage statistics
425    pub fn with_usage(mut self, usage: Usage) -> Self {
426        self.usage = Some(usage);
427        self
428    }
429
430    /// Add metadata
431    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
432        self.metadata.insert(key.into(), value);
433        self
434    }
435
436    /// Get the number of embeddings
437    pub fn count(&self) -> usize {
438        self.embeddings.len()
439    }
440
441    /// Get the embedding dimension (if any embeddings exist)
442    pub fn dimension(&self) -> Option<usize> {
443        self.embeddings.first().map(|emb| emb.len())
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_response_creation() {
        // Builder chain in a different order; result must be identical.
        let resp = ChatResponse::new("Hello!", "gpt-4")
            .with_usage(Usage::new(10, 5))
            .with_finish_reason("stop");

        assert_eq!(resp.model, "gpt-4");
        assert_eq!(resp.content, "Hello!");
        assert!(resp.usage.is_some());
        assert_eq!(resp.finish_reason.as_deref(), Some("stop"));
        assert!(resp.is_finished());
    }

    #[test]
    fn test_stream_chunk() {
        let chunk = StreamChunk::delta("Hello", "gpt-4").with_role(MessageRole::Assistant);

        assert!(chunk.is_delta && !chunk.is_done);
        assert_eq!(chunk.content, "Hello");
        assert!(chunk.has_content());
        assert_eq!(chunk.role, Some(MessageRole::Assistant));
    }

    #[test]
    fn test_usage_calculation() {
        let tally = Usage::new(100, 50).with_cached_tokens(20);

        assert_eq!(tally.total_tokens, 150);
        assert_eq!(tally.effective_prompt_tokens(), 80);
        assert_eq!(tally.total_cost(), 130);
    }
}