rig/providers/cohere/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest};
2use crate::providers::cohere::CompletionModel;
3use crate::providers::cohere::completion::Usage;
4use crate::streaming::RawStreamingChoice;
5use crate::{json_utils, streaming};
6use async_stream::stream;
7use futures::StreamExt;
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10
11#[derive(Debug, Deserialize)]
12#[serde(rename_all = "kebab-case", tag = "type")]
13enum StreamingEvent {
14 MessageStart,
15 ContentStart,
16 ContentDelta { delta: Option<Delta> },
17 ContentEnd,
18 ToolPlan,
19 ToolCallStart { delta: Option<Delta> },
20 ToolCallDelta { delta: Option<Delta> },
21 ToolCallEnd,
22 MessageEnd { delta: Option<MessageEndDelta> },
23}
24
25#[derive(Debug, Deserialize)]
26struct MessageContentDelta {
27 text: Option<String>,
28}
29
30#[derive(Debug, Deserialize)]
31struct MessageToolFunctionDelta {
32 name: Option<String>,
33 arguments: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolCallDelta {
38 id: Option<String>,
39 function: Option<MessageToolFunctionDelta>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageDelta {
44 content: Option<MessageContentDelta>,
45 tool_calls: Option<MessageToolCallDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct Delta {
50 message: Option<MessageDelta>,
51}
52
53#[derive(Debug, Deserialize)]
54struct MessageEndDelta {
55 usage: Option<Usage>,
56}
57
58#[derive(Clone, Serialize, Deserialize)]
59pub struct StreamingCompletionResponse {
60 pub usage: Option<Usage>,
61}
62
63impl CompletionModel {
64 pub(crate) async fn stream(
65 &self,
66 request: CompletionRequest,
67 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
68 {
69 let request = self.create_completion_request(request)?;
70 let request = json_utils::merge(request, json!({"stream": true}));
71
72 tracing::debug!(
73 "Cohere request: {}",
74 serde_json::to_string_pretty(&request)?
75 );
76
77 let response = self.client.post("/v2/chat").json(&request).send().await?;
78
79 if !response.status().is_success() {
80 return Err(CompletionError::ProviderError(format!(
81 "{}: {}",
82 response.status(),
83 response.text().await?
84 )));
85 }
86
87 let stream = Box::pin(stream! {
88 let mut stream = response.bytes_stream();
89 let mut current_tool_call: Option<(String, String, String)> = None;
90
91 while let Some(chunk_result) = stream.next().await {
92 let chunk = match chunk_result {
93 Ok(c) => c,
94 Err(e) => {
95 yield Err(CompletionError::from(e));
96 break;
97 }
98 };
99
100 let text = match String::from_utf8(chunk.to_vec()) {
101 Ok(t) => t,
102 Err(e) => {
103 yield Err(CompletionError::ResponseError(e.to_string()));
104 break;
105 }
106 };
107
108 for line in text.lines() {
109
110 let Some(line) = line.strip_prefix("data: ") else {
111 continue;
112 };
113
114 let event = {
115 let result = serde_json::from_str::<StreamingEvent>(line);
116
117 let Ok(event) = result else {
118 continue;
119 };
120
121 event
122 };
123
124 match event {
125 StreamingEvent::ContentDelta { delta: Some(delta) } => {
126 let Some(message) = &delta.message else { continue; };
127 let Some(content) = &message.content else { continue; };
128 let Some(text) = &content.text else { continue; };
129
130 yield Ok(RawStreamingChoice::Message(text.clone()));
131 },
132 StreamingEvent::MessageEnd {delta: Some(delta)} => {
133 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
134 usage: delta.usage.clone()
135 }));
136 },
137 StreamingEvent::ToolCallStart { delta: Some(delta)} => {
138 let Some(message) = &delta.message else { continue; };
141 let Some(tool_calls) = &message.tool_calls else { continue; };
142 let Some(id) = tool_calls.id.clone() else { continue; };
143 let Some(function) = &tool_calls.function else { continue; };
144 let Some(name) = function.name.clone() else { continue; };
145 let Some(arguments) = function.arguments.clone() else { continue; };
146
147 current_tool_call = Some((id, name, arguments));
148 },
149 StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
150 let Some(message) = &delta.message else { continue; };
153 let Some(tool_calls) = &message.tool_calls else { continue; };
154 let Some(function) = &tool_calls.function else { continue; };
155 let Some(arguments) = function.arguments.clone() else { continue; };
156
157 if let Some(tc) = current_tool_call.clone() {
158 current_tool_call = Some((
159 tc.0,
160 tc.1,
161 format!("{}{}", tc.2, arguments)
162 ));
163 };
164 },
165 StreamingEvent::ToolCallEnd => {
166 let Some(tc) = current_tool_call.clone() else { continue; };
167
168 let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
169
170 yield Ok(RawStreamingChoice::ToolCall {
171 id: tc.0,
172 name: tc.1,
173 arguments: args,
174 call_id: None
175 });
176
177 current_tool_call = None;
178 },
179 _ => {}
180 };
181 }
182 }
183 });
184
185 Ok(streaming::StreamingCompletionResponse::stream(stream))
186 }
187}