1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::providers::cohere::CompletionModel;
5use crate::providers::cohere::completion::{
6 AssistantContent, Message, ToolCall, ToolCallFunction, ToolType, Usage,
7};
8use crate::streaming::RawStreamingChoice;
9use crate::telemetry::SpanCombinator;
10use crate::{json_utils, streaming};
11use async_stream::stream;
12use futures::StreamExt;
13use http::Method;
14use serde::{Deserialize, Serialize};
15use tracing::info_span;
16use tracing_futures::Instrument;
17
/// One server-sent event payload received from the Cohere streaming chat call.
///
/// The variant is chosen by the JSON `type` field, whose accepted values are the
/// kebab-case forms of the variant names (for example `content-delta` or
/// `tool-call-start`). A payload whose `type` is not listed here fails to
/// deserialize; `stream` logs that and skips the event.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case", tag = "type")]
enum StreamingEvent {
    /// `message-start`; carries no data that `stream` uses.
    MessageStart,
    /// `content-start`; carries no data that `stream` uses.
    ContentStart,
    /// `content-delta`; a fragment of assistant text under `delta.message.content.text`.
    ContentDelta { delta: Option<Delta> },
    /// `content-end`; carries no data that `stream` uses.
    ContentEnd,
    /// `tool-plan`; not acted on by `stream`.
    ToolPlan,
    /// `tool-call-start`; opens a tool call (id, function name, initial arguments).
    ToolCallStart { delta: Option<Delta> },
    /// `tool-call-delta`; a further fragment of the open tool call's argument string.
    ToolCallDelta { delta: Option<Delta> },
    /// `tool-call-end`; closes the open tool call.
    ToolCallEnd,
    /// `message-end`; optionally carries token usage for the whole response.
    MessageEnd { delta: Option<MessageEndDelta> },
}
31
/// Text payload found at `delta.message.content` of a streaming event.
#[derive(Debug, Deserialize)]
struct MessageContentDelta {
    /// Text fragment, if the event carried one.
    text: Option<String>,
}
36
/// Function portion of a streamed tool call (`delta.message.tool_calls.function`).
#[derive(Debug, Deserialize)]
struct MessageToolFunctionDelta {
    /// Name of the function being called; `stream` reads it from `tool-call-start`.
    name: Option<String>,
    /// Fragment of the argument string; `stream` concatenates the fragments and
    /// parses the result as JSON when the tool call ends.
    arguments: Option<String>,
}
42
/// Tool-call payload found at `delta.message.tool_calls` of a streaming event.
#[derive(Debug, Deserialize)]
struct MessageToolCallDelta {
    /// Identifier of the tool call; `stream` requires it on `tool-call-start`.
    id: Option<String>,
    /// Function name and/or argument fragment for the tool call.
    function: Option<MessageToolFunctionDelta>,
}
48
/// The `message` object nested inside a streaming event's `delta`.
#[derive(Debug, Deserialize)]
struct MessageDelta {
    /// Text content update, present on content events.
    content: Option<MessageContentDelta>,
    /// Tool-call update, present on tool-call events. Despite the plural key it
    /// deserializes as a single object, not a list.
    tool_calls: Option<MessageToolCallDelta>,
}
54
/// The `delta` object of content and tool-call streaming events.
#[derive(Debug, Deserialize)]
struct Delta {
    /// Incremental message update, if the event carried one.
    message: Option<MessageDelta>,
}
59
/// The `delta` object of a `message-end` streaming event.
#[derive(Debug, Deserialize)]
struct MessageEndDelta {
    /// Token usage reported for the response, if present.
    usage: Option<Usage>,
}
64
/// Final response emitted by the Cohere streaming implementation once a
/// `message-end` event with a `delta` has been received.
#[derive(Clone, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Usage information copied from the `message-end` event, if it was present.
    pub usage: Option<Usage>,
}
69
70impl GetTokenUsage for StreamingCompletionResponse {
71 fn token_usage(&self) -> Option<crate::completion::Usage> {
72 let tokens = self
73 .usage
74 .clone()
75 .and_then(|response| response.tokens)
76 .map(|tokens| {
77 (
78 tokens.input_tokens.map(|x| x as u64),
79 tokens.output_tokens.map(|y| y as u64),
80 )
81 });
82 let Some((Some(input), Some(output))) = tokens else {
83 return None;
84 };
85 let mut usage = crate::completion::Usage::new();
86 usage.input_tokens = input;
87 usage.output_tokens = output;
88 usage.total_tokens = input + output;
89
90 Some(usage)
91 }
92}
93
94impl<T> CompletionModel<T>
95where
96 T: HttpClientExt + Clone + 'static,
97{
98 pub(crate) async fn stream(
99 &self,
100 request: CompletionRequest,
101 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
102 {
103 let request = self.create_completion_request(request)?;
104 let span = if tracing::Span::current().is_disabled() {
105 info_span!(
106 target: "rig::completions",
107 "chat_streaming",
108 gen_ai.operation.name = "chat_streaming",
109 gen_ai.provider.name = "cohere",
110 gen_ai.request.model = self.model,
111 gen_ai.response.id = tracing::field::Empty,
112 gen_ai.response.model = self.model,
113 gen_ai.usage.output_tokens = tracing::field::Empty,
114 gen_ai.usage.input_tokens = tracing::field::Empty,
115 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
116 gen_ai.output.messages = tracing::field::Empty,
117 )
118 } else {
119 tracing::Span::current()
120 };
121
122 let request = json_utils::merge(request, serde_json::json!({"stream": true}));
123
124 tracing::debug!(
125 "Cohere streaming completion input: {}",
126 serde_json::to_string_pretty(&request)?
127 );
128
129 let body = serde_json::to_vec(&request)?;
130
131 let req = self
132 .client
133 .req(Method::POST, "/v2/chat")?
134 .body(body)
135 .unwrap();
136
137 let mut event_source = GenericEventSource::new(self.client.http_client(), req);
138
139 let stream = stream! {
140 let mut current_tool_call: Option<(String, String, String)> = None;
141 let mut text_response = String::new();
142 let mut tool_calls = Vec::new();
143
144 while let Some(event_result) = event_source.next().await {
145 match event_result {
146 Ok(Event::Open) => {
147 tracing::trace!("SSE connection opened");
148 continue;
149 }
150
151 Ok(Event::Message(message)) => {
152 let data_str = message.data.trim();
153 if data_str.is_empty() || data_str == "[DONE]" {
154 continue;
155 }
156
157 let event: StreamingEvent = match serde_json::from_str(data_str) {
158 Ok(ev) => ev,
159 Err(_) => {
160 tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
161 continue;
162 }
163 };
164
165 match event {
166 StreamingEvent::ContentDelta { delta: Some(delta) } => {
167 let Some(message) = &delta.message else { continue; };
168 let Some(content) = &message.content else { continue; };
169 let Some(text) = &content.text else { continue; };
170
171 text_response += text;
172
173 yield Ok(RawStreamingChoice::Message(text.clone()));
174 },
175
176 StreamingEvent::MessageEnd { delta: Some(delta) } => {
177 let message = Message::Assistant {
178 tool_calls: tool_calls.clone(),
179 content: vec![AssistantContent::Text { text: text_response.clone() }],
180 tool_plan: None,
181 citations: vec![]
182 };
183
184 let span = tracing::Span::current();
185 span.record_token_usage(&delta.usage);
186 span.record_model_output(&vec![message]);
187
188 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
189 usage: delta.usage.clone()
190 }));
191 },
192
193 StreamingEvent::ToolCallStart { delta: Some(delta) } => {
194 let Some(message) = &delta.message else { continue; };
195 let Some(tool_calls) = &message.tool_calls else { continue; };
196 let Some(id) = tool_calls.id.clone() else { continue; };
197 let Some(function) = &tool_calls.function else { continue; };
198 let Some(name) = function.name.clone() else { continue; };
199 let Some(arguments) = function.arguments.clone() else { continue; };
200
201 current_tool_call = Some((id, name, arguments));
202 },
203
204 StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
205 let Some(message) = &delta.message else { continue; };
206 let Some(tool_calls) = &message.tool_calls else { continue; };
207 let Some(function) = &tool_calls.function else { continue; };
208 let Some(arguments) = function.arguments.clone() else { continue; };
209
210 let Some(tc) = current_tool_call.clone() else { continue; };
211 current_tool_call = Some((tc.0.clone(), tc.1, format!("{}{}", tc.2, arguments)));
212
213 yield Ok(RawStreamingChoice::ToolCallDelta {
215 id: tc.0,
216 delta: arguments,
217 });
218 },
219
220 StreamingEvent::ToolCallEnd => {
221 let Some(tc) = current_tool_call.clone() else { continue; };
222 let Ok(args) = serde_json::from_str::<serde_json::Value>(&tc.2) else { continue; };
223
224 tool_calls.push(ToolCall {
225 id: Some(tc.0.clone()),
226 r#type: Some(ToolType::Function),
227 function: Some(ToolCallFunction {
228 name: tc.1.clone(),
229 arguments: args.clone()
230 })
231 });
232
233 yield Ok(RawStreamingChoice::ToolCall {
234 id: tc.0,
235 name: tc.1,
236 arguments: args,
237 call_id: None
238 });
239
240 current_tool_call = None;
241 },
242
243 _ => {}
244 }
245 },
246 Err(crate::http_client::Error::StreamEnded) => {
247 break;
248 }
249 Err(err) => {
250 tracing::error!(?err, "SSE error");
251 yield Err(CompletionError::ResponseError(err.to_string()));
252 break;
253 }
254 }
255 }
256
257 event_source.close();
258 }.instrument(span);
259
260 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
261 stream,
262 )))
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `payload` into a [`StreamingEvent`], panicking if it does not parse.
    fn parse_event(payload: serde_json::Value) -> StreamingEvent {
        serde_json::from_value(payload).unwrap()
    }

    /// Builds a `content-delta` payload carrying `text`.
    fn content_delta_payload(text: &str) -> serde_json::Value {
        json!({
            "type": "content-delta",
            "delta": { "message": { "content": { "text": text } } }
        })
    }

    /// Builds a `tool-call-start` payload for the given call id, function name and arguments.
    fn tool_call_start_payload(id: &str, name: &str, arguments: &str) -> serde_json::Value {
        json!({
            "type": "tool-call-start",
            "delta": {
                "message": {
                    "tool_calls": {
                        "id": id,
                        "function": { "name": name, "arguments": arguments }
                    }
                }
            }
        })
    }

    /// Builds a `tool-call-delta` payload carrying an argument fragment.
    fn tool_call_delta_payload(arguments: &str) -> serde_json::Value {
        json!({
            "type": "tool-call-delta",
            "delta": {
                "message": {
                    "tool_calls": { "function": { "arguments": arguments } }
                }
            }
        })
    }

    /// Builds a `message-end` payload reporting the given token counts.
    fn message_end_payload(input_tokens: u64, output_tokens: u64) -> serde_json::Value {
        json!({
            "type": "message-end",
            "delta": {
                "usage": {
                    "tokens": {
                        "input_tokens": input_tokens,
                        "output_tokens": output_tokens
                    }
                }
            }
        })
    }

    #[test]
    fn test_message_content_delta_deserialization() {
        let event = parse_event(content_delta_payload("Hello world"));
        let StreamingEvent::ContentDelta { delta } = event else {
            panic!("Expected ContentDelta");
        };

        assert!(delta.is_some());
        let content = delta.unwrap().message.unwrap().content.unwrap();
        assert_eq!(content.text.as_deref(), Some("Hello world"));
    }

    #[test]
    fn test_tool_call_start_deserialization() {
        let event = parse_event(tool_call_start_payload("call_123", "get_weather", "{"));
        let StreamingEvent::ToolCallStart { delta } = event else {
            panic!("Expected ToolCallStart");
        };

        assert!(delta.is_some());
        let tool_call = delta.unwrap().message.unwrap().tool_calls.unwrap();
        assert_eq!(tool_call.id.as_deref(), Some("call_123"));
        assert_eq!(
            tool_call.function.unwrap().name.as_deref(),
            Some("get_weather")
        );
    }

    #[test]
    fn test_tool_call_delta_deserialization() {
        let event = parse_event(tool_call_delta_payload("\"location\""));
        let StreamingEvent::ToolCallDelta { delta } = event else {
            panic!("Expected ToolCallDelta");
        };

        assert!(delta.is_some());
        let function = delta
            .unwrap()
            .message
            .unwrap()
            .tool_calls
            .unwrap()
            .function
            .unwrap();
        assert_eq!(function.arguments.as_deref(), Some("\"location\""));
    }

    #[test]
    fn test_tool_call_end_deserialization() {
        let event = parse_event(json!({ "type": "tool-call-end" }));
        assert!(
            matches!(event, StreamingEvent::ToolCallEnd),
            "Expected ToolCallEnd"
        );
    }

    #[test]
    fn test_message_end_with_usage_deserialization() {
        let event = parse_event(message_end_payload(100, 50));
        let StreamingEvent::MessageEnd { delta } = event else {
            panic!("Expected MessageEnd");
        };

        assert!(delta.is_some());
        let tokens = delta.unwrap().usage.unwrap().tokens.unwrap();
        assert_eq!(tokens.input_tokens, Some(100.0));
        assert_eq!(tokens.output_tokens, Some(50.0));
    }

    #[test]
    fn test_streaming_event_order() {
        // A full exchange in arrival order: text content, then one tool call, then the end marker.
        let events = vec![
            json!({ "type": "message-start" }),
            json!({ "type": "content-start" }),
            content_delta_payload("Sure, "),
            content_delta_payload("I can help with that."),
            json!({ "type": "content-end" }),
            json!({ "type": "tool-plan" }),
            tool_call_start_payload("call_abc", "search", ""),
            tool_call_delta_payload("{\"query\":"),
            tool_call_delta_payload("\"Rust\"}"),
            json!({ "type": "tool-call-end" }),
            message_end_payload(50, 25),
        ];

        for (index, payload) in events.into_iter().enumerate() {
            let outcome = serde_json::from_value::<StreamingEvent>(payload);
            assert!(
                outcome.is_ok(),
                "Failed to deserialize event at index {}: {:?}",
                index,
                outcome.err()
            );
        }
    }
}