// cortexai_llm_client/response.rs

1//! Response parsing for LLM API responses
2//!
3//! This module provides runtime-agnostic parsing of responses from
4//! OpenAI, Anthropic, and OpenRouter APIs.
5
6use serde::{Deserialize, Serialize};
7
8use crate::error::{LlmClientError, Result};
9use crate::provider::Provider;
10
11/// Parsed LLM response
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct LlmResponse {
14    /// The generated text content
15    pub content: String,
16    /// Token usage information (if available)
17    pub usage: Option<Usage>,
18    /// Tool calls requested by the model (if any)
19    pub tool_calls: Vec<ToolCall>,
20    /// The finish reason
21    pub finish_reason: Option<String>,
22    /// Model used for generation
23    pub model: Option<String>,
24}
25
26/// Token usage information
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct Usage {
29    /// Number of tokens in the prompt
30    pub prompt_tokens: u32,
31    /// Number of tokens in the completion
32    pub completion_tokens: u32,
33    /// Total tokens used
34    pub total_tokens: u32,
35}
36
37/// Tool call requested by the model
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ToolCall {
40    /// Unique identifier for this tool call
41    pub id: String,
42    /// Name of the function to call
43    pub function_name: String,
44    /// JSON arguments for the function
45    pub arguments: String,
46}
47
/// Streaming chunk from SSE response
///
/// One chunk is produced per meaningful SSE `data:` line; a chunk may carry
/// a text delta, a tool-call fragment, a finish signal, or a mix of these.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Content delta (partial text); `None` for non-text events
    pub content: Option<String>,
    /// Whether this is the final chunk of the stream
    pub done: bool,
    /// Finish reason (only populated on a final chunk)
    pub finish_reason: Option<String>,
    /// Tool call chunk (for streaming tool calls); fragments for the same
    /// call must be accumulated by the caller
    pub tool_call_chunk: Option<ToolCallChunk>,
}
60
/// Partial tool call information from streaming
///
/// Derives `PartialEq`/`Eq` so accumulated fragments can be compared
/// directly in callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallChunk {
    /// Index of the tool call (used to group fragments of the same call)
    pub index: usize,
    /// Tool call ID (present on the first chunk only)
    pub id: Option<String>,
    /// Function name (present on the first chunk only)
    pub function_name: Option<String>,
    /// Arguments delta; concatenate deltas in order to rebuild the full
    /// JSON arguments string
    pub arguments_delta: Option<String>,
}
73
74/// Response parser for different providers
75pub struct ResponseParser;
76
77impl ResponseParser {
78    /// Parse a complete (non-streaming) response
79    pub fn parse(provider: Provider, body: &str) -> Result<LlmResponse> {
80        match provider {
81            Provider::OpenAI | Provider::OpenRouter => Self::parse_openai(body),
82            Provider::Anthropic => Self::parse_anthropic(body),
83        }
84    }
85
86    /// Parse a streaming SSE line
87    pub fn parse_stream_line(provider: Provider, line: &str) -> Result<Option<StreamChunk>> {
88        // Skip empty lines and comments
89        let line = line.trim();
90        if line.is_empty() || line.starts_with(':') {
91            return Ok(None);
92        }
93
94        // Extract data from SSE format
95        let data = if let Some(stripped) = line.strip_prefix("data: ") {
96            stripped.trim()
97        } else {
98            return Ok(None);
99        };
100
101        // Check for stream end
102        if data == "[DONE]" {
103            return Ok(Some(StreamChunk {
104                content: None,
105                done: true,
106                finish_reason: None,
107                tool_call_chunk: None,
108            }));
109        }
110
111        match provider {
112            Provider::OpenAI | Provider::OpenRouter => Self::parse_openai_stream_chunk(data),
113            Provider::Anthropic => Self::parse_anthropic_stream_chunk(data),
114        }
115    }
116
117    fn parse_openai(body: &str) -> Result<LlmResponse> {
118        let json: serde_json::Value = serde_json::from_str(body)
119            .map_err(|e| LlmClientError::UnexpectedFormat(e.to_string()))?;
120
121        // Check for error response
122        if json.get("error").is_some() {
123            return Err(LlmClientError::from_api_response(&json));
124        }
125
126        let choices = json
127            .get("choices")
128            .and_then(|c| c.as_array())
129            .ok_or_else(|| {
130                LlmClientError::UnexpectedFormat("Missing 'choices' field".to_string())
131            })?;
132
133        let first_choice = choices
134            .first()
135            .ok_or_else(|| LlmClientError::UnexpectedFormat("Empty choices array".to_string()))?;
136
137        let message = first_choice.get("message").ok_or_else(|| {
138            LlmClientError::UnexpectedFormat("Missing 'message' field".to_string())
139        })?;
140
141        let content = message
142            .get("content")
143            .and_then(|c| c.as_str())
144            .unwrap_or("")
145            .to_string();
146
147        let finish_reason = first_choice
148            .get("finish_reason")
149            .and_then(|f| f.as_str())
150            .map(|s| s.to_string());
151
152        let model = json
153            .get("model")
154            .and_then(|m| m.as_str())
155            .map(|s| s.to_string());
156
157        // Parse tool calls if present
158        let tool_calls = Self::parse_openai_tool_calls(message);
159
160        // Parse usage
161        let usage = json.get("usage").and_then(|u| {
162            Some(Usage {
163                prompt_tokens: u.get("prompt_tokens")?.as_u64()? as u32,
164                completion_tokens: u.get("completion_tokens")?.as_u64()? as u32,
165                total_tokens: u.get("total_tokens")?.as_u64()? as u32,
166            })
167        });
168
169        Ok(LlmResponse {
170            content,
171            usage,
172            tool_calls,
173            finish_reason,
174            model,
175        })
176    }
177
178    fn parse_openai_tool_calls(message: &serde_json::Value) -> Vec<ToolCall> {
179        let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) else {
180            return Vec::new();
181        };
182
183        tool_calls
184            .iter()
185            .filter_map(|tc| {
186                let id = tc.get("id")?.as_str()?.to_string();
187                let function = tc.get("function")?;
188                let function_name = function.get("name")?.as_str()?.to_string();
189                let arguments = function.get("arguments")?.as_str()?.to_string();
190                Some(ToolCall {
191                    id,
192                    function_name,
193                    arguments,
194                })
195            })
196            .collect()
197    }
198
199    fn parse_anthropic(body: &str) -> Result<LlmResponse> {
200        let json: serde_json::Value = serde_json::from_str(body)
201            .map_err(|e| LlmClientError::UnexpectedFormat(e.to_string()))?;
202
203        // Check for error response
204        if json.get("error").is_some() {
205            return Err(LlmClientError::from_api_response(&json));
206        }
207
208        // Anthropic returns content as an array of blocks
209        let content_blocks = json
210            .get("content")
211            .and_then(|c| c.as_array())
212            .ok_or_else(|| {
213                LlmClientError::UnexpectedFormat("Missing 'content' field".to_string())
214            })?;
215
216        let mut content = String::new();
217        let mut tool_calls = Vec::new();
218
219        for block in content_blocks {
220            let block_type = block.get("type").and_then(|t| t.as_str()).unwrap_or("");
221
222            match block_type {
223                "text" => {
224                    if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
225                        if !content.is_empty() {
226                            content.push('\n');
227                        }
228                        content.push_str(text);
229                    }
230                }
231                "tool_use" => {
232                    if let (Some(id), Some(name), Some(input)) = (
233                        block.get("id").and_then(|i| i.as_str()),
234                        block.get("name").and_then(|n| n.as_str()),
235                        block.get("input"),
236                    ) {
237                        tool_calls.push(ToolCall {
238                            id: id.to_string(),
239                            function_name: name.to_string(),
240                            arguments: serde_json::to_string(input).unwrap_or_default(),
241                        });
242                    }
243                }
244                _ => {}
245            }
246        }
247
248        let finish_reason = json
249            .get("stop_reason")
250            .and_then(|s| s.as_str())
251            .map(|s| s.to_string());
252
253        let model = json
254            .get("model")
255            .and_then(|m| m.as_str())
256            .map(|s| s.to_string());
257
258        let usage = json.get("usage").and_then(|u| {
259            Some(Usage {
260                prompt_tokens: u.get("input_tokens")?.as_u64()? as u32,
261                completion_tokens: u.get("output_tokens")?.as_u64()? as u32,
262                total_tokens: (u.get("input_tokens")?.as_u64()?
263                    + u.get("output_tokens")?.as_u64()?) as u32,
264            })
265        });
266
267        Ok(LlmResponse {
268            content,
269            usage,
270            tool_calls,
271            finish_reason,
272            model,
273        })
274    }
275
276    fn parse_openai_stream_chunk(data: &str) -> Result<Option<StreamChunk>> {
277        let json: serde_json::Value = serde_json::from_str(data)
278            .map_err(|e| LlmClientError::UnexpectedFormat(e.to_string()))?;
279
280        let choices = json.get("choices").and_then(|c| c.as_array());
281        let Some(choices) = choices else {
282            return Ok(None);
283        };
284
285        let Some(first_choice) = choices.first() else {
286            return Ok(None);
287        };
288
289        let finish_reason = first_choice.get("finish_reason").and_then(|f| {
290            if f.is_null() {
291                None
292            } else {
293                f.as_str().map(|s| s.to_string())
294            }
295        });
296
297        let delta = first_choice.get("delta");
298
299        let content = delta
300            .and_then(|d| d.get("content"))
301            .and_then(|c| c.as_str())
302            .map(|s| s.to_string());
303
304        // Parse streaming tool calls
305        let tool_call_chunk = delta
306            .and_then(|d| d.get("tool_calls"))
307            .and_then(|t| t.as_array())
308            .and_then(|arr| arr.first())
309            .and_then(|tc| {
310                let index = tc.get("index")?.as_u64()? as usize;
311                let id = tc.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
312                let function = tc.get("function");
313                let function_name = function
314                    .and_then(|f| f.get("name"))
315                    .and_then(|n| n.as_str())
316                    .map(|s| s.to_string());
317                let arguments_delta = function
318                    .and_then(|f| f.get("arguments"))
319                    .and_then(|a| a.as_str())
320                    .map(|s| s.to_string());
321
322                Some(ToolCallChunk {
323                    index,
324                    id,
325                    function_name,
326                    arguments_delta,
327                })
328            });
329
330        let done = finish_reason.is_some();
331
332        Ok(Some(StreamChunk {
333            content,
334            done,
335            finish_reason,
336            tool_call_chunk,
337        }))
338    }
339
340    fn parse_anthropic_stream_chunk(data: &str) -> Result<Option<StreamChunk>> {
341        let json: serde_json::Value = serde_json::from_str(data)
342            .map_err(|e| LlmClientError::UnexpectedFormat(e.to_string()))?;
343
344        let event_type = json.get("type").and_then(|t| t.as_str()).unwrap_or("");
345
346        match event_type {
347            "content_block_delta" => {
348                let delta = json.get("delta");
349                let delta_type = delta
350                    .and_then(|d| d.get("type"))
351                    .and_then(|t| t.as_str())
352                    .unwrap_or("");
353
354                match delta_type {
355                    "text_delta" => {
356                        let content = delta
357                            .and_then(|d| d.get("text"))
358                            .and_then(|t| t.as_str())
359                            .map(|s| s.to_string());
360
361                        Ok(Some(StreamChunk {
362                            content,
363                            done: false,
364                            finish_reason: None,
365                            tool_call_chunk: None,
366                        }))
367                    }
368                    "input_json_delta" => {
369                        let partial_json = delta
370                            .and_then(|d| d.get("partial_json"))
371                            .and_then(|p| p.as_str())
372                            .map(|s| s.to_string());
373
374                        let index =
375                            json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
376
377                        Ok(Some(StreamChunk {
378                            content: None,
379                            done: false,
380                            finish_reason: None,
381                            tool_call_chunk: Some(ToolCallChunk {
382                                index,
383                                id: None,
384                                function_name: None,
385                                arguments_delta: partial_json,
386                            }),
387                        }))
388                    }
389                    _ => Ok(None),
390                }
391            }
392            "content_block_start" => {
393                let content_block = json.get("content_block");
394                let block_type = content_block
395                    .and_then(|b| b.get("type"))
396                    .and_then(|t| t.as_str())
397                    .unwrap_or("");
398
399                if block_type == "tool_use" {
400                    let id = content_block
401                        .and_then(|b| b.get("id"))
402                        .and_then(|i| i.as_str())
403                        .map(|s| s.to_string());
404                    let name = content_block
405                        .and_then(|b| b.get("name"))
406                        .and_then(|n| n.as_str())
407                        .map(|s| s.to_string());
408                    let index = json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
409
410                    Ok(Some(StreamChunk {
411                        content: None,
412                        done: false,
413                        finish_reason: None,
414                        tool_call_chunk: Some(ToolCallChunk {
415                            index,
416                            id,
417                            function_name: name,
418                            arguments_delta: None,
419                        }),
420                    }))
421                } else {
422                    Ok(None)
423                }
424            }
425            "message_delta" => {
426                let stop_reason = json
427                    .get("delta")
428                    .and_then(|d| d.get("stop_reason"))
429                    .and_then(|s| s.as_str())
430                    .map(|s| s.to_string());
431
432                Ok(Some(StreamChunk {
433                    content: None,
434                    done: stop_reason.is_some(),
435                    finish_reason: stop_reason,
436                    tool_call_chunk: None,
437                }))
438            }
439            "message_stop" => Ok(Some(StreamChunk {
440                content: None,
441                done: true,
442                finish_reason: Some("end_turn".to_string()),
443                tool_call_chunk: None,
444            })),
445            _ => Ok(None),
446        }
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain-text OpenAI completion: content, finish reason, model and
    /// usage counters all round-trip through the parser.
    #[test]
    fn test_parse_openai_response() {
        let payload = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello, world!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;

        let parsed = ResponseParser::parse(Provider::OpenAI, payload).unwrap();

        assert!(parsed.tool_calls.is_empty());
        assert_eq!(parsed.content, "Hello, world!");
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        assert_eq!(parsed.model.as_deref(), Some("gpt-4"));

        let usage = parsed.usage.expect("usage should be present");
        assert_eq!(
            (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens),
            (10, 5, 15)
        );
    }

    /// OpenAI message with null content and a single function tool call.
    #[test]
    fn test_parse_openai_with_tool_calls() {
        let payload = r#"{
            "id": "chatcmpl-123",
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"location\": \"Paris\"}"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        }"#;

        let parsed = ResponseParser::parse(Provider::OpenAI, payload).unwrap();

        let [call] = parsed.tool_calls.as_slice() else {
            panic!("expected exactly one tool call");
        };
        assert_eq!(call.id, "call_123");
        assert_eq!(call.function_name, "get_weather");
        assert_eq!(call.arguments, "{\"location\": \"Paris\"}");
    }

    /// Anthropic message with a single text block and separate token counts.
    #[test]
    fn test_parse_anthropic_response() {
        let payload = r#"{
            "id": "msg_123",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-opus-20240229",
            "content": [{
                "type": "text",
                "text": "Hello from Claude!"
            }],
            "stop_reason": "end_turn",
            "usage": {
                "input_tokens": 10,
                "output_tokens": 5
            }
        }"#;

        let parsed = ResponseParser::parse(Provider::Anthropic, payload).unwrap();

        assert_eq!(parsed.content, "Hello from Claude!");
        assert_eq!(parsed.finish_reason.as_deref(), Some("end_turn"));
        assert_eq!(parsed.model.as_deref(), Some("claude-3-opus-20240229"));

        // total_tokens is derived as input + output.
        let usage = parsed.usage.expect("usage should be present");
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.completion_tokens, 5);
        assert_eq!(usage.total_tokens, 15);
    }

    /// Anthropic tool_use block becomes a normalized ToolCall.
    #[test]
    fn test_parse_anthropic_with_tool_use() {
        let payload = r#"{
            "id": "msg_123",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-opus-20240229",
            "content": [{
                "type": "tool_use",
                "id": "toolu_123",
                "name": "get_weather",
                "input": {"location": "Paris"}
            }],
            "stop_reason": "tool_use"
        }"#;

        let parsed = ResponseParser::parse(Provider::Anthropic, payload).unwrap();

        assert_eq!(parsed.tool_calls.len(), 1);
        let call = &parsed.tool_calls[0];
        assert_eq!(call.id, "toolu_123");
        assert_eq!(call.function_name, "get_weather");
    }

    /// Mid-stream OpenAI delta: text content, not done, no finish reason.
    #[test]
    fn test_parse_openai_stream_chunk() {
        let event = r#"{"id":"chatcmpl-123","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let line = format!("data: {}", event);

        let chunk = ResponseParser::parse_stream_line(Provider::OpenAI, &line)
            .unwrap()
            .unwrap();

        assert!(!chunk.done);
        assert!(chunk.finish_reason.is_none());
        assert_eq!(chunk.content.as_deref(), Some("Hello"));
    }

    /// The "[DONE]" sentinel terminates OpenAI-style streams.
    #[test]
    fn test_parse_stream_done() {
        let chunk = ResponseParser::parse_stream_line(Provider::OpenAI, "data: [DONE]")
            .unwrap()
            .unwrap();

        assert!(chunk.content.is_none());
        assert!(chunk.done);
    }

    /// Anthropic content_block_delta with a text_delta yields a text chunk.
    #[test]
    fn test_parse_anthropic_stream_text_delta() {
        let event = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        let line = format!("data: {}", event);

        let chunk = ResponseParser::parse_stream_line(Provider::Anthropic, &line)
            .unwrap()
            .unwrap();

        assert!(!chunk.done);
        assert_eq!(chunk.content.as_deref(), Some("Hello"));
    }

    /// A body carrying a top-level "error" object maps to ApiError.
    #[test]
    fn test_parse_error_response() {
        let payload = r#"{"error": {"message": "Invalid API key", "type": "invalid_request_error"}}"#;

        let outcome = ResponseParser::parse(Provider::OpenAI, payload);
        assert!(matches!(outcome, Err(LlmClientError::ApiError { .. })));
    }
}