rig/providers/cohere/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::providers::cohere::CompletionModel;
3use crate::providers::cohere::completion::{
4 AssistantContent, Message, ToolCall, ToolCallFunction, ToolType, Usage,
5};
6use crate::streaming::RawStreamingChoice;
7use crate::telemetry::SpanCombinator;
8use crate::{json_utils, streaming};
9use async_stream::stream;
10use futures::StreamExt;
11use reqwest_eventsource::Event;
12use serde::{Deserialize, Serialize};
13use tracing::info_span;
14use tracing_futures::Instrument;
15
16#[derive(Debug, Deserialize)]
17#[serde(rename_all = "kebab-case", tag = "type")]
18enum StreamingEvent {
19 MessageStart,
20 ContentStart,
21 ContentDelta { delta: Option<Delta> },
22 ContentEnd,
23 ToolPlan,
24 ToolCallStart { delta: Option<Delta> },
25 ToolCallDelta { delta: Option<Delta> },
26 ToolCallEnd,
27 MessageEnd { delta: Option<MessageEndDelta> },
28}
29
30#[derive(Debug, Deserialize)]
31struct MessageContentDelta {
32 text: Option<String>,
33}
34
35#[derive(Debug, Deserialize)]
36struct MessageToolFunctionDelta {
37 name: Option<String>,
38 arguments: Option<String>,
39}
40
41#[derive(Debug, Deserialize)]
42struct MessageToolCallDelta {
43 id: Option<String>,
44 function: Option<MessageToolFunctionDelta>,
45}
46
47#[derive(Debug, Deserialize)]
48struct MessageDelta {
49 content: Option<MessageContentDelta>,
50 tool_calls: Option<MessageToolCallDelta>,
51}
52
53#[derive(Debug, Deserialize)]
54struct Delta {
55 message: Option<MessageDelta>,
56}
57
58#[derive(Debug, Deserialize)]
59struct MessageEndDelta {
60 usage: Option<Usage>,
61}
62
63#[derive(Clone, Serialize, Deserialize)]
64pub struct StreamingCompletionResponse {
65 pub usage: Option<Usage>,
66}
67
68impl GetTokenUsage for StreamingCompletionResponse {
69 fn token_usage(&self) -> Option<crate::completion::Usage> {
70 let tokens = self
71 .usage
72 .clone()
73 .and_then(|response| response.tokens)
74 .map(|tokens| {
75 (
76 tokens.input_tokens.map(|x| x as u64),
77 tokens.output_tokens.map(|y| y as u64),
78 )
79 });
80 let Some((Some(input), Some(output))) = tokens else {
81 return None;
82 };
83 let mut usage = crate::completion::Usage::new();
84 usage.input_tokens = input;
85 usage.output_tokens = output;
86 usage.total_tokens = input + output;
87
88 Some(usage)
89 }
90}
91
92impl CompletionModel<reqwest::Client> {
93 pub(crate) async fn stream(
94 &self,
95 request: CompletionRequest,
96 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
97 {
98 let request = self.create_completion_request(request)?;
99 let span = if tracing::Span::current().is_disabled() {
100 info_span!(
101 target: "rig::completions",
102 "chat_streaming",
103 gen_ai.operation.name = "chat_streaming",
104 gen_ai.provider.name = "cohere",
105 gen_ai.request.model = self.model,
106 gen_ai.response.id = tracing::field::Empty,
107 gen_ai.response.model = self.model,
108 gen_ai.usage.output_tokens = tracing::field::Empty,
109 gen_ai.usage.input_tokens = tracing::field::Empty,
110 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
111 gen_ai.output.messages = tracing::field::Empty,
112 )
113 } else {
114 tracing::Span::current()
115 };
116
117 let request = json_utils::merge(request, serde_json::json!({"stream": true}));
118
119 tracing::debug!(
120 "Cohere streaming completion input: {}",
121 serde_json::to_string_pretty(&request)?
122 );
123
124 let req = self.client.client().post("/v2/chat").json(&request);
125
126 let mut event_source = self
127 .client
128 .eventsource(req)
129 .await
130 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
131
132 let stream = stream! {
133 let mut current_tool_call: Option<(String, String, String)> = None;
134 let mut text_response = String::new();
135 let mut tool_calls = Vec::new();
136
137 while let Some(event_result) = event_source.next().await {
138 match event_result {
139 Ok(Event::Open) => {
140 tracing::trace!("SSE connection opened");
141 continue;
142 }
143
144 Ok(Event::Message(message)) => {
145 let data_str = message.data.trim();
146 if data_str.is_empty() || data_str == "[DONE]" {
147 continue;
148 }
149
150 let event: StreamingEvent = match serde_json::from_str(data_str) {
151 Ok(ev) => ev,
152 Err(_) => {
153 tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
154 continue;
155 }
156 };
157
158 match event {
159 StreamingEvent::ContentDelta { delta: Some(delta) } => {
160 let Some(message) = &delta.message else { continue; };
161 let Some(content) = &message.content else { continue; };
162 let Some(text) = &content.text else { continue; };
163
164 text_response += text;
165
166 yield Ok(RawStreamingChoice::Message(text.clone()));
167 },
168
169 StreamingEvent::MessageEnd { delta: Some(delta) } => {
170 let message = Message::Assistant {
171 tool_calls: tool_calls.clone(),
172 content: vec![AssistantContent::Text { text: text_response.clone() }],
173 tool_plan: None,
174 citations: vec![]
175 };
176
177 let span = tracing::Span::current();
178 span.record_token_usage(&delta.usage);
179 span.record_model_output(&vec![message]);
180
181 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
182 usage: delta.usage.clone()
183 }));
184 },
185
186 StreamingEvent::ToolCallStart { delta: Some(delta) } => {
187 let Some(message) = &delta.message else { continue; };
188 let Some(tool_calls) = &message.tool_calls else { continue; };
189 let Some(id) = tool_calls.id.clone() else { continue; };
190 let Some(function) = &tool_calls.function else { continue; };
191 let Some(name) = function.name.clone() else { continue; };
192 let Some(arguments) = function.arguments.clone() else { continue; };
193
194 current_tool_call = Some((id, name, arguments));
195 },
196
197 StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
198 let Some(message) = &delta.message else { continue; };
199 let Some(tool_calls) = &message.tool_calls else { continue; };
200 let Some(function) = &tool_calls.function else { continue; };
201 let Some(arguments) = function.arguments.clone() else { continue; };
202
203 let Some(tc) = current_tool_call.clone() else { continue; };
204 current_tool_call = Some((tc.0, tc.1, format!("{}{}", tc.2, arguments)));
205 },
206
207 StreamingEvent::ToolCallEnd => {
208 let Some(tc) = current_tool_call.clone() else { continue; };
209 let Ok(args) = serde_json::from_str::<serde_json::Value>(&tc.2) else { continue; };
210
211 tool_calls.push(ToolCall {
212 id: Some(tc.0.clone()),
213 r#type: Some(ToolType::Function),
214 function: Some(ToolCallFunction {
215 name: tc.1.clone(),
216 arguments: args.clone()
217 })
218 });
219
220 yield Ok(RawStreamingChoice::ToolCall {
221 id: tc.0,
222 name: tc.1,
223 arguments: args,
224 call_id: None
225 });
226
227 current_tool_call = None;
228 },
229
230 _ => {}
231 }
232 },
233
234 Err(reqwest_eventsource::Error::StreamEnded) => break,
235
236 Err(err) => {
237 tracing::error!(?err, "SSE error");
238 yield Err(CompletionError::ResponseError(err.to_string()));
239 break;
240 }
241 }
242 }
243
244 event_source.close();
245 }.instrument(span);
246
247 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
248 stream,
249 )))
250 }
251}