Skip to main content

rig/providers/cohere/
streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::providers::cohere::CompletionModel;
5use crate::providers::cohere::completion::{
6    AssistantContent, CohereCompletionRequest, Message, ToolCall, ToolCallFunction, ToolType, Usage,
7};
8use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent};
9use crate::telemetry::SpanCombinator;
10use crate::{json_utils, streaming};
11use async_stream::stream;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tracing::{Level, enabled, info_span};
15use tracing_futures::Instrument;
16
17#[derive(Debug, Deserialize)]
18#[serde(rename_all = "kebab-case", tag = "type")]
19enum StreamingEvent {
20    MessageStart,
21    ContentStart,
22    ContentDelta { delta: Option<Delta> },
23    ContentEnd,
24    ToolPlan,
25    ToolCallStart { delta: Option<Delta> },
26    ToolCallDelta { delta: Option<Delta> },
27    ToolCallEnd,
28    MessageEnd { delta: Option<MessageEndDelta> },
29}
30
31#[derive(Debug, Deserialize)]
32struct MessageContentDelta {
33    text: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolFunctionDelta {
38    name: Option<String>,
39    arguments: Option<String>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageToolCallDelta {
44    id: Option<String>,
45    function: Option<MessageToolFunctionDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct MessageDelta {
50    content: Option<MessageContentDelta>,
51    tool_calls: Option<MessageToolCallDelta>,
52}
53
54#[derive(Debug, Deserialize)]
55struct Delta {
56    message: Option<MessageDelta>,
57}
58
59#[derive(Debug, Deserialize)]
60struct MessageEndDelta {
61    usage: Option<Usage>,
62}
63
64#[derive(Clone, Serialize, Deserialize)]
65pub struct StreamingCompletionResponse {
66    pub usage: Option<Usage>,
67}
68
69impl GetTokenUsage for StreamingCompletionResponse {
70    fn token_usage(&self) -> Option<crate::completion::Usage> {
71        let tokens = self
72            .usage
73            .clone()
74            .and_then(|response| response.tokens)
75            .map(|tokens| {
76                (
77                    tokens.input_tokens.map(|x| x as u64),
78                    tokens.output_tokens.map(|y| y as u64),
79                )
80            });
81        let Some((Some(input), Some(output))) = tokens else {
82            return None;
83        };
84        let mut usage = crate::completion::Usage::new();
85        usage.input_tokens = input;
86        usage.output_tokens = output;
87        usage.total_tokens = input + output;
88
89        Some(usage)
90    }
91}
92
93impl<T> CompletionModel<T>
94where
95    T: HttpClientExt + Clone + 'static,
96{
97    pub(crate) async fn stream(
98        &self,
99        request: CompletionRequest,
100    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
101    {
102        let mut request = CohereCompletionRequest::try_from((self.model.as_ref(), request))?;
103        let span = if tracing::Span::current().is_disabled() {
104            info_span!(
105                target: "rig::completions",
106                "chat_streaming",
107                gen_ai.operation.name = "chat_streaming",
108                gen_ai.provider.name = "cohere",
109                gen_ai.request.model = self.model,
110                gen_ai.response.id = tracing::field::Empty,
111                gen_ai.response.model = self.model,
112                gen_ai.usage.output_tokens = tracing::field::Empty,
113                gen_ai.usage.input_tokens = tracing::field::Empty,
114                gen_ai.usage.cached_tokens = tracing::field::Empty,
115            )
116        } else {
117            tracing::Span::current()
118        };
119
120        let params = json_utils::merge(
121            request.additional_params.unwrap_or(serde_json::json!({})),
122            serde_json::json!({"stream": true}),
123        );
124
125        request.additional_params = Some(params);
126
127        if enabled!(Level::TRACE) {
128            tracing::trace!(
129                target: "rig::streaming",
130                "Cohere streaming completion input: {}",
131                serde_json::to_string_pretty(&request)?
132            );
133        }
134
135        let body = serde_json::to_vec(&request)?;
136
137        let req = self.client.post("/v2/chat")?.body(body).unwrap();
138
139        let mut event_source = GenericEventSource::new(self.client.clone(), req);
140
141        let stream = stream! {
142            let mut current_tool_call: Option<(String, String, String, String)> = None;
143            let mut text_response = String::new();
144            let mut tool_calls = Vec::new();
145            let mut final_usage = None;
146
147            while let Some(event_result) = event_source.next().await {
148                match event_result {
149                    Ok(Event::Open) => {
150                        tracing::trace!("SSE connection opened");
151                        continue;
152                    }
153
154                    Ok(Event::Message(message)) => {
155                        let data_str = message.data.trim();
156                        if data_str.is_empty() || data_str == "[DONE]" {
157                            continue;
158                        }
159
160                        let event: StreamingEvent = match serde_json::from_str(data_str) {
161                            Ok(ev) => ev,
162                            Err(_) => {
163                                tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
164                                continue;
165                            }
166                        };
167
168                        match event {
169                            StreamingEvent::ContentDelta { delta: Some(delta) } => {
170                                let Some(message) = &delta.message else { continue; };
171                                let Some(content) = &message.content else { continue; };
172                                let Some(text) = &content.text else { continue; };
173
174                                text_response += text;
175
176                                yield Ok(RawStreamingChoice::Message(text.clone()));
177                            },
178
179                            StreamingEvent::MessageEnd { delta: Some(delta) } => {
180                                let message = Message::Assistant {
181                                    tool_calls: tool_calls.clone(),
182                                    content: vec![AssistantContent::Text { text: text_response.clone() }],
183                                    tool_plan: None,
184                                    citations: vec![]
185                                };
186
187                                let span = tracing::Span::current();
188                                span.record_token_usage(&delta.usage);
189                                span.record_model_output(&vec![message]);
190
191                                final_usage = Some(delta.usage.clone());
192                                break;
193                            },
194
195                            StreamingEvent::ToolCallStart { delta: Some(delta) } => {
196                                let Some(message) = &delta.message else { continue; };
197                                let Some(tool_calls) = &message.tool_calls else { continue; };
198                                let Some(id) = tool_calls.id.clone() else { continue; };
199                                let Some(function) = &tool_calls.function else { continue; };
200                                let Some(name) = function.name.clone() else { continue; };
201                                let Some(arguments) = function.arguments.clone() else { continue; };
202
203                                let internal_call_id = nanoid::nanoid!();
204                                current_tool_call = Some((id.clone(), internal_call_id.clone(), name.clone(), arguments));
205
206                                yield Ok(RawStreamingChoice::ToolCallDelta {
207                                    id,
208                                    internal_call_id,
209                                    content: ToolCallDeltaContent::Name(name),
210                                });
211                            },
212
213                            StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
214                                let Some(message) = &delta.message else { continue; };
215                                let Some(tool_calls) = &message.tool_calls else { continue; };
216                                let Some(function) = &tool_calls.function else { continue; };
217                                let Some(arguments) = function.arguments.clone() else { continue; };
218
219                                let Some(tc) = current_tool_call.clone() else { continue; };
220                                current_tool_call = Some((tc.0.clone(), tc.1.clone(), tc.2, format!("{}{}", tc.3, arguments)));
221
222                                // Emit the delta so UI can show progress
223                                yield Ok(RawStreamingChoice::ToolCallDelta {
224                                    id: tc.0,
225                                    internal_call_id: tc.1,
226                                    content: ToolCallDeltaContent::Delta(arguments),
227                                });
228                            },
229
230                            StreamingEvent::ToolCallEnd => {
231                                let Some(tc) = current_tool_call.clone() else { continue; };
232                                let Ok(args) = serde_json::from_str::<serde_json::Value>(&tc.3) else { continue; };
233
234                                tool_calls.push(ToolCall {
235                                    id: Some(tc.0.clone()),
236                                    r#type: Some(ToolType::Function),
237                                    function: Some(ToolCallFunction {
238                                        name: tc.2.clone(),
239                                        arguments: args.clone()
240                                    })
241                                });
242
243                                let raw_tool_call = RawStreamingToolCall::new(tc.0, tc.2, args)
244                                    .with_internal_call_id(tc.1);
245                                yield Ok(RawStreamingChoice::ToolCall(raw_tool_call));
246
247                                current_tool_call = None;
248                            },
249
250                            _ => {}
251                        }
252                    },
253                    Err(crate::http_client::Error::StreamEnded) => {
254                        break;
255                    }
256                    Err(err) => {
257                        tracing::error!(?err, "SSE error");
258                        yield Err(CompletionError::ProviderError(err.to_string()));
259                        break;
260                    }
261                }
262            }
263
264            // Ensure event source is closed when stream ends
265            event_source.close();
266
267            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
268                usage: final_usage.unwrap_or_default()
269            }))
270        }.instrument(span);
271
272        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
273            stream,
274        )))
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Decodes `payload` as a [`StreamingEvent`], panicking if that fails.
    fn decode(payload: serde_json::Value) -> StreamingEvent {
        serde_json::from_value(payload).unwrap()
    }

    /// A `content-delta` payload exposes its text through the nested delta.
    #[test]
    fn test_message_content_delta_deserialization() {
        let event = decode(json!({
            "type": "content-delta",
            "delta": {
                "message": {
                    "content": {
                        "text": "Hello world"
                    }
                }
            }
        }));

        let StreamingEvent::ContentDelta { delta } = event else {
            panic!("Expected ContentDelta");
        };
        let content = delta
            .and_then(|d| d.message)
            .and_then(|m| m.content)
            .expect("content should be present");
        assert_eq!(content.text.as_deref(), Some("Hello world"));
    }

    /// A `tool-call-start` payload exposes the call id and function name.
    #[test]
    fn test_tool_call_start_deserialization() {
        let event = decode(json!({
            "type": "tool-call-start",
            "delta": {
                "message": {
                    "tool_calls": {
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{"
                        }
                    }
                }
            }
        }));

        let StreamingEvent::ToolCallStart { delta } = event else {
            panic!("Expected ToolCallStart");
        };
        let call = delta
            .and_then(|d| d.message)
            .and_then(|m| m.tool_calls)
            .expect("tool call should be present");
        assert_eq!(call.id.as_deref(), Some("call_123"));

        let function = call.function.expect("function should be present");
        assert_eq!(function.name.as_deref(), Some("get_weather"));
    }

    /// A `tool-call-delta` payload exposes the raw argument chunk.
    #[test]
    fn test_tool_call_delta_deserialization() {
        let event = decode(json!({
            "type": "tool-call-delta",
            "delta": {
                "message": {
                    "tool_calls": {
                        "function": {
                            "arguments": "\"location\""
                        }
                    }
                }
            }
        }));

        let StreamingEvent::ToolCallDelta { delta } = event else {
            panic!("Expected ToolCallDelta");
        };
        let function = delta
            .and_then(|d| d.message)
            .and_then(|m| m.tool_calls)
            .and_then(|c| c.function)
            .expect("function should be present");
        assert_eq!(function.arguments.as_deref(), Some("\"location\""));
    }

    /// A bare `tool-call-end` payload decodes to the unit variant.
    #[test]
    fn test_tool_call_end_deserialization() {
        let event = decode(json!({
            "type": "tool-call-end"
        }));

        assert!(
            matches!(event, StreamingEvent::ToolCallEnd),
            "Expected ToolCallEnd"
        );
    }

    /// A `message-end` payload exposes the reported token counts.
    #[test]
    fn test_message_end_with_usage_deserialization() {
        let event = decode(json!({
            "type": "message-end",
            "delta": {
                "usage": {
                    "tokens": {
                        "input_tokens": 100,
                        "output_tokens": 50
                    }
                }
            }
        }));

        let StreamingEvent::MessageEnd { delta } = event else {
            panic!("Expected MessageEnd");
        };
        let tokens = delta
            .and_then(|d| d.usage)
            .and_then(|u| u.tokens)
            .expect("token counts should be present");
        assert_eq!(tokens.input_tokens, Some(100.0));
        assert_eq!(tokens.output_tokens, Some(50.0));
    }

    /// Every payload of a typical event sequence decodes successfully.
    #[test]
    fn test_streaming_event_order() {
        let sequence = vec![
            json!({"type": "message-start"}),
            json!({"type": "content-start"}),
            json!({
                "type": "content-delta",
                "delta": {
                    "message": {
                        "content": {
                            "text": "Sure, "
                        }
                    }
                }
            }),
            json!({
                "type": "content-delta",
                "delta": {
                    "message": {
                        "content": {
                            "text": "I can help with that."
                        }
                    }
                }
            }),
            json!({"type": "content-end"}),
            json!({"type": "tool-plan"}),
            json!({
                "type": "tool-call-start",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "id": "call_abc",
                            "function": {
                                "name": "search",
                                "arguments": ""
                            }
                        }
                    }
                }
            }),
            json!({
                "type": "tool-call-delta",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "function": {
                                "arguments": "{\"query\":"
                            }
                        }
                    }
                }
            }),
            json!({
                "type": "tool-call-delta",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "function": {
                                "arguments": "\"Rust\"}"
                            }
                        }
                    }
                }
            }),
            json!({"type": "tool-call-end"}),
            json!({
                "type": "message-end",
                "delta": {
                    "usage": {
                        "tokens": {
                            "input_tokens": 50,
                            "output_tokens": 25
                        }
                    }
                }
            }),
        ];

        for (i, payload) in sequence.into_iter().enumerate() {
            let result = serde_json::from_value::<StreamingEvent>(payload);
            assert!(
                result.is_ok(),
                "Failed to deserialize event at index {}: {:?}",
                i,
                result.err()
            );
        }
    }
}