rig/providers/cohere/
streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::providers::cohere::CompletionModel;
3use crate::providers::cohere::completion::Usage;
4use crate::streaming::RawStreamingChoice;
5use crate::{json_utils, streaming};
6use async_stream::stream;
7use futures::StreamExt;
8use reqwest_eventsource::{Event, RequestBuilderExt};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Deserialize)]
12#[serde(rename_all = "kebab-case", tag = "type")]
13enum StreamingEvent {
14    MessageStart,
15    ContentStart,
16    ContentDelta { delta: Option<Delta> },
17    ContentEnd,
18    ToolPlan,
19    ToolCallStart { delta: Option<Delta> },
20    ToolCallDelta { delta: Option<Delta> },
21    ToolCallEnd,
22    MessageEnd { delta: Option<MessageEndDelta> },
23}
24
25#[derive(Debug, Deserialize)]
26struct MessageContentDelta {
27    text: Option<String>,
28}
29
30#[derive(Debug, Deserialize)]
31struct MessageToolFunctionDelta {
32    name: Option<String>,
33    arguments: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolCallDelta {
38    id: Option<String>,
39    function: Option<MessageToolFunctionDelta>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageDelta {
44    content: Option<MessageContentDelta>,
45    tool_calls: Option<MessageToolCallDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct Delta {
50    message: Option<MessageDelta>,
51}
52
53#[derive(Debug, Deserialize)]
54struct MessageEndDelta {
55    usage: Option<Usage>,
56}
57
58#[derive(Clone, Serialize, Deserialize)]
59pub struct StreamingCompletionResponse {
60    pub usage: Option<Usage>,
61}
62
63impl GetTokenUsage for StreamingCompletionResponse {
64    fn token_usage(&self) -> Option<crate::completion::Usage> {
65        let tokens = self
66            .usage
67            .clone()
68            .and_then(|response| response.tokens)
69            .map(|tokens| {
70                (
71                    tokens.input_tokens.map(|x| x as u64),
72                    tokens.output_tokens.map(|y| y as u64),
73                )
74            });
75        let Some((Some(input), Some(output))) = tokens else {
76            return None;
77        };
78        let mut usage = crate::completion::Usage::new();
79        usage.input_tokens = input;
80        usage.output_tokens = output;
81        usage.total_tokens = input + output;
82
83        Some(usage)
84    }
85}
86
87impl CompletionModel {
88    pub(crate) async fn stream(
89        &self,
90        request: CompletionRequest,
91    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
92    {
93        let request = self.create_completion_request(request)?;
94        let request = json_utils::merge(request, serde_json::json!({"stream": true}));
95
96        tracing::debug!(
97            "Cohere request: {}",
98            serde_json::to_string_pretty(&request)?
99        );
100
101        let mut event_source = self
102            .client
103            .post("/v2/chat")
104            .json(&request)
105            .eventsource()
106            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
107
108        let stream = Box::pin(stream! {
109            let mut current_tool_call: Option<(String, String, String)> = None;
110
111            while let Some(event_result) = event_source.next().await {
112                match event_result {
113                    Ok(Event::Open) => {
114                        tracing::trace!("SSE connection opened");
115                        continue;
116                    }
117
118                    Ok(Event::Message(message)) => {
119                        let data_str = message.data.trim();
120                        if data_str.is_empty() || data_str == "[DONE]" {
121                            continue;
122                        }
123
124                        let event: StreamingEvent = match serde_json::from_str(data_str) {
125                            Ok(ev) => ev,
126                            Err(_) => {
127                                tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
128                                continue;
129                            }
130                        };
131
132                        match event {
133                            StreamingEvent::ContentDelta { delta: Some(delta) } => {
134                                let Some(message) = &delta.message else { continue; };
135                                let Some(content) = &message.content else { continue; };
136                                let Some(text) = &content.text else { continue; };
137
138                                yield Ok(RawStreamingChoice::Message(text.clone()));
139                            },
140
141                            StreamingEvent::MessageEnd { delta: Some(delta) } => {
142                                yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
143                                    usage: delta.usage.clone()
144                                }));
145                            },
146
147                            StreamingEvent::ToolCallStart { delta: Some(delta) } => {
148                                let Some(message) = &delta.message else { continue; };
149                                let Some(tool_calls) = &message.tool_calls else { continue; };
150                                let Some(id) = tool_calls.id.clone() else { continue; };
151                                let Some(function) = &tool_calls.function else { continue; };
152                                let Some(name) = function.name.clone() else { continue; };
153                                let Some(arguments) = function.arguments.clone() else { continue; };
154
155                                current_tool_call = Some((id, name, arguments));
156                            },
157
158                            StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
159                                let Some(message) = &delta.message else { continue; };
160                                let Some(tool_calls) = &message.tool_calls else { continue; };
161                                let Some(function) = &tool_calls.function else { continue; };
162                                let Some(arguments) = function.arguments.clone() else { continue; };
163
164                                let Some(tc) = current_tool_call.clone() else { continue; };
165                                current_tool_call = Some((tc.0, tc.1, format!("{}{}", tc.2, arguments)));
166                            },
167
168                            StreamingEvent::ToolCallEnd => {
169                                let Some(tc) = current_tool_call.clone() else { continue; };
170                                let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
171
172                                yield Ok(RawStreamingChoice::ToolCall {
173                                    id: tc.0,
174                                    name: tc.1,
175                                    arguments: args,
176                                    call_id: None
177                                });
178
179                                current_tool_call = None;
180                            },
181
182                            _ => {}
183                        }
184                    },
185
186                    Err(reqwest_eventsource::Error::StreamEnded) => break,
187
188                    Err(err) => {
189                        tracing::error!(?err, "SSE error");
190                        yield Err(CompletionError::ResponseError(err.to_string()));
191                        break;
192                    }
193                }
194            }
195
196            event_source.close();
197        });
198
199        Ok(streaming::StreamingCompletionResponse::stream(stream))
200    }
201}