rig/providers/cohere/
streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::providers::cohere::CompletionModel;
3use crate::providers::cohere::completion::{
4    AssistantContent, Message, ToolCall, ToolCallFunction, ToolType, Usage,
5};
6use crate::streaming::RawStreamingChoice;
7use crate::telemetry::SpanCombinator;
8use crate::{json_utils, streaming};
9use async_stream::stream;
10use futures::StreamExt;
11use reqwest_eventsource::Event;
12use serde::{Deserialize, Serialize};
13use tracing::info_span;
14use tracing_futures::Instrument;
15
16#[derive(Debug, Deserialize)]
17#[serde(rename_all = "kebab-case", tag = "type")]
18enum StreamingEvent {
19    MessageStart,
20    ContentStart,
21    ContentDelta { delta: Option<Delta> },
22    ContentEnd,
23    ToolPlan,
24    ToolCallStart { delta: Option<Delta> },
25    ToolCallDelta { delta: Option<Delta> },
26    ToolCallEnd,
27    MessageEnd { delta: Option<MessageEndDelta> },
28}
29
30#[derive(Debug, Deserialize)]
31struct MessageContentDelta {
32    text: Option<String>,
33}
34
35#[derive(Debug, Deserialize)]
36struct MessageToolFunctionDelta {
37    name: Option<String>,
38    arguments: Option<String>,
39}
40
41#[derive(Debug, Deserialize)]
42struct MessageToolCallDelta {
43    id: Option<String>,
44    function: Option<MessageToolFunctionDelta>,
45}
46
47#[derive(Debug, Deserialize)]
48struct MessageDelta {
49    content: Option<MessageContentDelta>,
50    tool_calls: Option<MessageToolCallDelta>,
51}
52
53#[derive(Debug, Deserialize)]
54struct Delta {
55    message: Option<MessageDelta>,
56}
57
58#[derive(Debug, Deserialize)]
59struct MessageEndDelta {
60    usage: Option<Usage>,
61}
62
63#[derive(Clone, Serialize, Deserialize)]
64pub struct StreamingCompletionResponse {
65    pub usage: Option<Usage>,
66}
67
68impl GetTokenUsage for StreamingCompletionResponse {
69    fn token_usage(&self) -> Option<crate::completion::Usage> {
70        let tokens = self
71            .usage
72            .clone()
73            .and_then(|response| response.tokens)
74            .map(|tokens| {
75                (
76                    tokens.input_tokens.map(|x| x as u64),
77                    tokens.output_tokens.map(|y| y as u64),
78                )
79            });
80        let Some((Some(input), Some(output))) = tokens else {
81            return None;
82        };
83        let mut usage = crate::completion::Usage::new();
84        usage.input_tokens = input;
85        usage.output_tokens = output;
86        usage.total_tokens = input + output;
87
88        Some(usage)
89    }
90}
91
92impl CompletionModel<reqwest::Client> {
93    pub(crate) async fn stream(
94        &self,
95        request: CompletionRequest,
96    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
97    {
98        let request = self.create_completion_request(request)?;
99        let span = if tracing::Span::current().is_disabled() {
100            info_span!(
101                target: "rig::completions",
102                "chat_streaming",
103                gen_ai.operation.name = "chat_streaming",
104                gen_ai.provider.name = "cohere",
105                gen_ai.request.model = self.model,
106                gen_ai.response.id = tracing::field::Empty,
107                gen_ai.response.model = self.model,
108                gen_ai.usage.output_tokens = tracing::field::Empty,
109                gen_ai.usage.input_tokens = tracing::field::Empty,
110                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
111                gen_ai.output.messages = tracing::field::Empty,
112            )
113        } else {
114            tracing::Span::current()
115        };
116
117        let request = json_utils::merge(request, serde_json::json!({"stream": true}));
118
119        tracing::debug!(
120            "Cohere streaming completion input: {}",
121            serde_json::to_string_pretty(&request)?
122        );
123
124        let req = self.client.client().post("/v2/chat").json(&request);
125
126        let mut event_source = self
127            .client
128            .eventsource(req)
129            .await
130            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
131
132        let stream = stream! {
133            let mut current_tool_call: Option<(String, String, String)> = None;
134            let mut text_response = String::new();
135            let mut tool_calls = Vec::new();
136
137            while let Some(event_result) = event_source.next().await {
138                match event_result {
139                    Ok(Event::Open) => {
140                        tracing::trace!("SSE connection opened");
141                        continue;
142                    }
143
144                    Ok(Event::Message(message)) => {
145                        let data_str = message.data.trim();
146                        if data_str.is_empty() || data_str == "[DONE]" {
147                            continue;
148                        }
149
150                        let event: StreamingEvent = match serde_json::from_str(data_str) {
151                            Ok(ev) => ev,
152                            Err(_) => {
153                                tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
154                                continue;
155                            }
156                        };
157
158                        match event {
159                            StreamingEvent::ContentDelta { delta: Some(delta) } => {
160                                let Some(message) = &delta.message else { continue; };
161                                let Some(content) = &message.content else { continue; };
162                                let Some(text) = &content.text else { continue; };
163
164                                text_response += text;
165
166                                yield Ok(RawStreamingChoice::Message(text.clone()));
167                            },
168
169                            StreamingEvent::MessageEnd { delta: Some(delta) } => {
170                                let message = Message::Assistant {
171                                    tool_calls: tool_calls.clone(),
172                                    content: vec![AssistantContent::Text { text: text_response.clone() }],
173                                    tool_plan: None,
174                                    citations: vec![]
175                                };
176
177                                let span = tracing::Span::current();
178                                span.record_token_usage(&delta.usage);
179                                span.record_model_output(&vec![message]);
180
181                                yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
182                                    usage: delta.usage.clone()
183                                }));
184                            },
185
186                            StreamingEvent::ToolCallStart { delta: Some(delta) } => {
187                                let Some(message) = &delta.message else { continue; };
188                                let Some(tool_calls) = &message.tool_calls else { continue; };
189                                let Some(id) = tool_calls.id.clone() else { continue; };
190                                let Some(function) = &tool_calls.function else { continue; };
191                                let Some(name) = function.name.clone() else { continue; };
192                                let Some(arguments) = function.arguments.clone() else { continue; };
193
194                                current_tool_call = Some((id, name, arguments));
195                            },
196
197                            StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
198                                let Some(message) = &delta.message else { continue; };
199                                let Some(tool_calls) = &message.tool_calls else { continue; };
200                                let Some(function) = &tool_calls.function else { continue; };
201                                let Some(arguments) = function.arguments.clone() else { continue; };
202
203                                let Some(tc) = current_tool_call.clone() else { continue; };
204                                current_tool_call = Some((tc.0, tc.1, format!("{}{}", tc.2, arguments)));
205                            },
206
207                            StreamingEvent::ToolCallEnd => {
208                                let Some(tc) = current_tool_call.clone() else { continue; };
209                                let Ok(args) = serde_json::from_str::<serde_json::Value>(&tc.2) else { continue; };
210
211                                tool_calls.push(ToolCall {
212                                    id: Some(tc.0.clone()),
213                                    r#type: Some(ToolType::Function),
214                                    function: Some(ToolCallFunction {
215                                        name: tc.1.clone(),
216                                        arguments: args.clone()
217                                    })
218                                });
219
220                                yield Ok(RawStreamingChoice::ToolCall {
221                                    id: tc.0,
222                                    name: tc.1,
223                                    arguments: args,
224                                    call_id: None
225                                });
226
227                                current_tool_call = None;
228                            },
229
230                            _ => {}
231                        }
232                    },
233
234                    Err(reqwest_eventsource::Error::StreamEnded) => break,
235
236                    Err(err) => {
237                        tracing::error!(?err, "SSE error");
238                        yield Err(CompletionError::ResponseError(err.to_string()));
239                        break;
240                    }
241                }
242            }
243
244            event_source.close();
245        }.instrument(span);
246
247        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
248            stream,
249        )))
250    }
251}