rig/providers/cohere/streaming.rs

1use crate::completion::{CompletionError, CompletionRequest};
2use crate::providers::cohere::CompletionModel;
3use crate::providers::cohere::completion::Usage;
4use crate::streaming::RawStreamingChoice;
5use crate::{json_utils, streaming};
6use async_stream::stream;
7use futures::StreamExt;
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10
11#[derive(Debug, Deserialize)]
12#[serde(rename_all = "kebab-case", tag = "type")]
13enum StreamingEvent {
14    MessageStart,
15    ContentStart,
16    ContentDelta { delta: Option<Delta> },
17    ContentEnd,
18    ToolPlan,
19    ToolCallStart { delta: Option<Delta> },
20    ToolCallDelta { delta: Option<Delta> },
21    ToolCallEnd,
22    MessageEnd { delta: Option<MessageEndDelta> },
23}
24
25#[derive(Debug, Deserialize)]
26struct MessageContentDelta {
27    text: Option<String>,
28}
29
30#[derive(Debug, Deserialize)]
31struct MessageToolFunctionDelta {
32    name: Option<String>,
33    arguments: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolCallDelta {
38    id: Option<String>,
39    function: Option<MessageToolFunctionDelta>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageDelta {
44    content: Option<MessageContentDelta>,
45    tool_calls: Option<MessageToolCallDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct Delta {
50    message: Option<MessageDelta>,
51}
52
53#[derive(Debug, Deserialize)]
54struct MessageEndDelta {
55    usage: Option<Usage>,
56}
57
58#[derive(Clone, Serialize, Deserialize)]
59pub struct StreamingCompletionResponse {
60    pub usage: Option<Usage>,
61}
62
63impl CompletionModel {
64    pub(crate) async fn stream(
65        &self,
66        request: CompletionRequest,
67    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
68    {
69        let request = self.create_completion_request(request)?;
70        let request = json_utils::merge(request, json!({"stream": true}));
71
72        tracing::debug!(
73            "Cohere request: {}",
74            serde_json::to_string_pretty(&request)?
75        );
76
77        let response = self.client.post("/v2/chat").json(&request).send().await?;
78
79        if !response.status().is_success() {
80            return Err(CompletionError::ProviderError(format!(
81                "{}: {}",
82                response.status(),
83                response.text().await?
84            )));
85        }
86
87        let stream = Box::pin(stream! {
88            let mut stream = response.bytes_stream();
89            let mut current_tool_call: Option<(String, String, String)> = None;
90
91            while let Some(chunk_result) = stream.next().await {
92               let chunk = match chunk_result {
93                    Ok(c) => c,
94                    Err(e) => {
95                        yield Err(CompletionError::from(e));
96                        break;
97                    }
98                };
99
100               let text = match String::from_utf8(chunk.to_vec()) {
101                    Ok(t) => t,
102                    Err(e) => {
103                        yield Err(CompletionError::ResponseError(e.to_string()));
104                        break;
105                    }
106               };
107
108                for line in text.lines() {
109
110                    let Some(line) = line.strip_prefix("data: ") else {
111                        continue;
112                    };
113
114                    let event = {
115                       let result = serde_json::from_str::<StreamingEvent>(line);
116
117                       let Ok(event) = result else {
118                           continue;
119                       };
120
121                        event
122                    };
123
124                    match event {
125                        StreamingEvent::ContentDelta { delta: Some(delta) } => {
126                            let Some(message) = &delta.message else { continue; };
127                            let Some(content) = &message.content else { continue; };
128                            let Some(text) = &content.text else { continue; };
129
130                            yield Ok(RawStreamingChoice::Message(text.clone()));
131                        },
132                        StreamingEvent::MessageEnd {delta: Some(delta)} => {
133                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
134                                usage: delta.usage.clone()
135                            }));
136                        },
137                        StreamingEvent::ToolCallStart { delta: Some(delta)} => {
138                            // Skip the delta if there's any missing information,
139                            // though this *should* all be present
140                            let Some(message) = &delta.message else { continue; };
141                            let Some(tool_calls) = &message.tool_calls else { continue; };
142                            let Some(id) = tool_calls.id.clone() else { continue; };
143                            let Some(function) = &tool_calls.function else { continue; };
144                            let Some(name) = function.name.clone() else { continue; };
145                            let Some(arguments) = function.arguments.clone() else { continue; };
146
147                            current_tool_call = Some((id, name, arguments));
148                        },
149                        StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
150                            // Skip the delta if there's any missing information,
151                            // though this *should* all be present
152                            let Some(message) = &delta.message else { continue; };
153                            let Some(tool_calls) = &message.tool_calls else { continue; };
154                            let Some(function) = &tool_calls.function else { continue; };
155                            let Some(arguments) = function.arguments.clone() else { continue; };
156
157                            if let Some(tc) = current_tool_call.clone() {
158                                current_tool_call = Some((
159                                    tc.0,
160                                    tc.1,
161                                    format!("{}{}", tc.2, arguments)
162                                ));
163                            };
164                        },
165                        StreamingEvent::ToolCallEnd => {
166                            let Some(tc) = current_tool_call.clone() else { continue; };
167
168                            let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
169
170                            yield Ok(RawStreamingChoice::ToolCall {
171                                id: tc.0,
172                                name: tc.1,
173                                arguments: args,
174                                call_id: None
175                            });
176
177                            current_tool_call = None;
178                        },
179                        _ => {}
180                    };
181                }
182            }
183        });
184
185        Ok(streaming::StreamingCompletionResponse::stream(stream))
186    }
187}