rig/providers/cohere/streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::providers::cohere::CompletionModel;
5use crate::providers::cohere::completion::{
6    AssistantContent, Message, ToolCall, ToolCallFunction, ToolType, Usage,
7};
8use crate::streaming::RawStreamingChoice;
9use crate::telemetry::SpanCombinator;
10use crate::{json_utils, streaming};
11use async_stream::stream;
12use futures::StreamExt;
13use http::Method;
14use serde::{Deserialize, Serialize};
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18#[derive(Debug, Deserialize)]
19#[serde(rename_all = "kebab-case", tag = "type")]
20enum StreamingEvent {
21    MessageStart,
22    ContentStart,
23    ContentDelta { delta: Option<Delta> },
24    ContentEnd,
25    ToolPlan,
26    ToolCallStart { delta: Option<Delta> },
27    ToolCallDelta { delta: Option<Delta> },
28    ToolCallEnd,
29    MessageEnd { delta: Option<MessageEndDelta> },
30}
31
32#[derive(Debug, Deserialize)]
33struct MessageContentDelta {
34    text: Option<String>,
35}
36
37#[derive(Debug, Deserialize)]
38struct MessageToolFunctionDelta {
39    name: Option<String>,
40    arguments: Option<String>,
41}
42
43#[derive(Debug, Deserialize)]
44struct MessageToolCallDelta {
45    id: Option<String>,
46    function: Option<MessageToolFunctionDelta>,
47}
48
49#[derive(Debug, Deserialize)]
50struct MessageDelta {
51    content: Option<MessageContentDelta>,
52    tool_calls: Option<MessageToolCallDelta>,
53}
54
55#[derive(Debug, Deserialize)]
56struct Delta {
57    message: Option<MessageDelta>,
58}
59
60#[derive(Debug, Deserialize)]
61struct MessageEndDelta {
62    usage: Option<Usage>,
63}
64
65#[derive(Clone, Serialize, Deserialize)]
66pub struct StreamingCompletionResponse {
67    pub usage: Option<Usage>,
68}
69
70impl GetTokenUsage for StreamingCompletionResponse {
71    fn token_usage(&self) -> Option<crate::completion::Usage> {
72        let tokens = self
73            .usage
74            .clone()
75            .and_then(|response| response.tokens)
76            .map(|tokens| {
77                (
78                    tokens.input_tokens.map(|x| x as u64),
79                    tokens.output_tokens.map(|y| y as u64),
80                )
81            });
82        let Some((Some(input), Some(output))) = tokens else {
83            return None;
84        };
85        let mut usage = crate::completion::Usage::new();
86        usage.input_tokens = input;
87        usage.output_tokens = output;
88        usage.total_tokens = input + output;
89
90        Some(usage)
91    }
92}
93
94impl<T> CompletionModel<T>
95where
96    T: HttpClientExt + Clone + 'static,
97{
98    pub(crate) async fn stream(
99        &self,
100        request: CompletionRequest,
101    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
102    {
103        let request = self.create_completion_request(request)?;
104        let span = if tracing::Span::current().is_disabled() {
105            info_span!(
106                target: "rig::completions",
107                "chat_streaming",
108                gen_ai.operation.name = "chat_streaming",
109                gen_ai.provider.name = "cohere",
110                gen_ai.request.model = self.model,
111                gen_ai.response.id = tracing::field::Empty,
112                gen_ai.response.model = self.model,
113                gen_ai.usage.output_tokens = tracing::field::Empty,
114                gen_ai.usage.input_tokens = tracing::field::Empty,
115                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
116                gen_ai.output.messages = tracing::field::Empty,
117            )
118        } else {
119            tracing::Span::current()
120        };
121
122        let request = json_utils::merge(request, serde_json::json!({"stream": true}));
123
124        tracing::debug!(
125            "Cohere streaming completion input: {}",
126            serde_json::to_string_pretty(&request)?
127        );
128
129        let body = serde_json::to_vec(&request)?;
130
131        let req = self
132            .client
133            .req(Method::POST, "/v2/chat")?
134            .body(body)
135            .unwrap();
136
137        let mut event_source = GenericEventSource::new(self.client.http_client(), req);
138
139        let stream = stream! {
140            let mut current_tool_call: Option<(String, String, String)> = None;
141            let mut text_response = String::new();
142            let mut tool_calls = Vec::new();
143
144            while let Some(event_result) = event_source.next().await {
145                match event_result {
146                    Ok(Event::Open) => {
147                        tracing::trace!("SSE connection opened");
148                        continue;
149                    }
150
151                    Ok(Event::Message(message)) => {
152                        let data_str = message.data.trim();
153                        if data_str.is_empty() || data_str == "[DONE]" {
154                            continue;
155                        }
156
157                        let event: StreamingEvent = match serde_json::from_str(data_str) {
158                            Ok(ev) => ev,
159                            Err(_) => {
160                                tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
161                                continue;
162                            }
163                        };
164
165                        match event {
166                            StreamingEvent::ContentDelta { delta: Some(delta) } => {
167                                let Some(message) = &delta.message else { continue; };
168                                let Some(content) = &message.content else { continue; };
169                                let Some(text) = &content.text else { continue; };
170
171                                text_response += text;
172
173                                yield Ok(RawStreamingChoice::Message(text.clone()));
174                            },
175
176                            StreamingEvent::MessageEnd { delta: Some(delta) } => {
177                                let message = Message::Assistant {
178                                    tool_calls: tool_calls.clone(),
179                                    content: vec![AssistantContent::Text { text: text_response.clone() }],
180                                    tool_plan: None,
181                                    citations: vec![]
182                                };
183
184                                let span = tracing::Span::current();
185                                span.record_token_usage(&delta.usage);
186                                span.record_model_output(&vec![message]);
187
188                                yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
189                                    usage: delta.usage.clone()
190                                }));
191                            },
192
193                            StreamingEvent::ToolCallStart { delta: Some(delta) } => {
194                                let Some(message) = &delta.message else { continue; };
195                                let Some(tool_calls) = &message.tool_calls else { continue; };
196                                let Some(id) = tool_calls.id.clone() else { continue; };
197                                let Some(function) = &tool_calls.function else { continue; };
198                                let Some(name) = function.name.clone() else { continue; };
199                                let Some(arguments) = function.arguments.clone() else { continue; };
200
201                                current_tool_call = Some((id, name, arguments));
202                            },
203
204                            StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
205                                let Some(message) = &delta.message else { continue; };
206                                let Some(tool_calls) = &message.tool_calls else { continue; };
207                                let Some(function) = &tool_calls.function else { continue; };
208                                let Some(arguments) = function.arguments.clone() else { continue; };
209
210                                let Some(tc) = current_tool_call.clone() else { continue; };
211                                current_tool_call = Some((tc.0.clone(), tc.1, format!("{}{}", tc.2, arguments)));
212
213                                // Emit the delta so UI can show progress
214                                yield Ok(RawStreamingChoice::ToolCallDelta {
215                                    id: tc.0,
216                                    delta: arguments,
217                                });
218                            },
219
220                            StreamingEvent::ToolCallEnd => {
221                                let Some(tc) = current_tool_call.clone() else { continue; };
222                                let Ok(args) = serde_json::from_str::<serde_json::Value>(&tc.2) else { continue; };
223
224                                tool_calls.push(ToolCall {
225                                    id: Some(tc.0.clone()),
226                                    r#type: Some(ToolType::Function),
227                                    function: Some(ToolCallFunction {
228                                        name: tc.1.clone(),
229                                        arguments: args.clone()
230                                    })
231                                });
232
233                                yield Ok(RawStreamingChoice::ToolCall {
234                                    id: tc.0,
235                                    name: tc.1,
236                                    arguments: args,
237                                    call_id: None
238                                });
239
240                                current_tool_call = None;
241                            },
242
243                            _ => {}
244                        }
245                    },
246                    Err(crate::http_client::Error::StreamEnded) => {
247                        break;
248                    }
249                    Err(err) => {
250                        tracing::error!(?err, "SSE error");
251                        yield Err(CompletionError::ResponseError(err.to_string()));
252                        break;
253                    }
254                }
255            }
256
257            event_source.close();
258        }.instrument(span);
259
260        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
261            stream,
262        )))
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a JSON fixture into a [`StreamingEvent`], panicking if it does not parse.
    fn parse_event(payload: serde_json::Value) -> StreamingEvent {
        serde_json::from_value(payload).unwrap()
    }

    #[test]
    fn test_message_content_delta_deserialization() {
        let payload = json!({
            "type": "content-delta",
            "delta": {
                "message": {
                    "content": {
                        "text": "Hello world"
                    }
                }
            }
        });

        let StreamingEvent::ContentDelta { delta } = parse_event(payload) else {
            panic!("Expected ContentDelta");
        };
        assert!(delta.is_some());

        let text = delta
            .and_then(|delta| delta.message)
            .and_then(|message| message.content)
            .and_then(|content| content.text);
        assert_eq!(text.as_deref(), Some("Hello world"));
    }

    #[test]
    fn test_tool_call_start_deserialization() {
        let payload = json!({
            "type": "tool-call-start",
            "delta": {
                "message": {
                    "tool_calls": {
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{"
                        }
                    }
                }
            }
        });

        let StreamingEvent::ToolCallStart { delta } = parse_event(payload) else {
            panic!("Expected ToolCallStart");
        };
        assert!(delta.is_some());

        let tool_call = delta
            .and_then(|delta| delta.message)
            .and_then(|message| message.tool_calls)
            .unwrap();
        assert_eq!(tool_call.id.as_deref(), Some("call_123"));

        let function_name = tool_call.function.and_then(|function| function.name);
        assert_eq!(function_name.as_deref(), Some("get_weather"));
    }

    #[test]
    fn test_tool_call_delta_deserialization() {
        let payload = json!({
            "type": "tool-call-delta",
            "delta": {
                "message": {
                    "tool_calls": {
                        "function": {
                            "arguments": "\"location\""
                        }
                    }
                }
            }
        });

        let StreamingEvent::ToolCallDelta { delta } = parse_event(payload) else {
            panic!("Expected ToolCallDelta");
        };
        assert!(delta.is_some());

        let arguments = delta
            .and_then(|delta| delta.message)
            .and_then(|message| message.tool_calls)
            .and_then(|tool_call| tool_call.function)
            .and_then(|function| function.arguments);
        assert_eq!(arguments.as_deref(), Some("\"location\""));
    }

    #[test]
    fn test_tool_call_end_deserialization() {
        let event = parse_event(json!({
            "type": "tool-call-end"
        }));

        assert!(
            matches!(event, StreamingEvent::ToolCallEnd),
            "Expected ToolCallEnd"
        );
    }

    #[test]
    fn test_message_end_with_usage_deserialization() {
        let payload = json!({
            "type": "message-end",
            "delta": {
                "usage": {
                    "tokens": {
                        "input_tokens": 100,
                        "output_tokens": 50
                    }
                }
            }
        });

        let StreamingEvent::MessageEnd { delta } = parse_event(payload) else {
            panic!("Expected MessageEnd");
        };
        assert!(delta.is_some());

        let tokens = delta
            .and_then(|delta| delta.usage)
            .and_then(|usage| usage.tokens)
            .unwrap();
        assert_eq!(tokens.input_tokens, Some(100.0));
        assert_eq!(tokens.output_tokens, Some(50.0));
    }

    #[test]
    fn test_streaming_event_order() {
        // A typical end-to-end sequence of events must deserialize without error.
        let events = vec![
            json!({"type": "message-start"}),
            json!({"type": "content-start"}),
            json!({
                "type": "content-delta",
                "delta": {
                    "message": {
                        "content": {
                            "text": "Sure, "
                        }
                    }
                }
            }),
            json!({
                "type": "content-delta",
                "delta": {
                    "message": {
                        "content": {
                            "text": "I can help with that."
                        }
                    }
                }
            }),
            json!({"type": "content-end"}),
            json!({"type": "tool-plan"}),
            json!({
                "type": "tool-call-start",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "id": "call_abc",
                            "function": {
                                "name": "search",
                                "arguments": ""
                            }
                        }
                    }
                }
            }),
            json!({
                "type": "tool-call-delta",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "function": {
                                "arguments": "{\"query\":"
                            }
                        }
                    }
                }
            }),
            json!({
                "type": "tool-call-delta",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "function": {
                                "arguments": "\"Rust\"}"
                            }
                        }
                    }
                }
            }),
            json!({"type": "tool-call-end"}),
            json!({
                "type": "message-end",
                "delta": {
                    "usage": {
                        "tokens": {
                            "input_tokens": 50,
                            "output_tokens": 25
                        }
                    }
                }
            }),
        ];

        for (index, payload) in events.into_iter().enumerate() {
            let outcome = serde_json::from_value::<StreamingEvent>(payload);
            assert!(
                outcome.is_ok(),
                "Failed to deserialize event at index {}: {:?}",
                index,
                outcome.err()
            );
        }
    }
}
485        }
486    }
487}