rig/providers/gemini/
streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use reqwest_eventsource::{Event, RequestBuilderExt};
4use serde::{Deserialize, Serialize};
5
6use super::completion::{
7    CompletionModel, create_request_body,
8    gemini_api_types::{ContentCandidate, Part, PartKind},
9};
10use crate::{
11    completion::{CompletionError, CompletionRequest, GetTokenUsage},
12    streaming::{self},
13};
14
15#[derive(Debug, Deserialize, Serialize, Default, Clone)]
16#[serde(rename_all = "camelCase")]
17pub struct PartialUsage {
18    pub total_token_count: i32,
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub cached_content_token_count: Option<i32>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub candidates_token_count: Option<i32>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub thoughts_token_count: Option<i32>,
25    pub prompt_token_count: i32,
26}
27
28#[derive(Debug, Deserialize)]
29#[serde(rename_all = "camelCase")]
30pub struct StreamGenerateContentResponse {
31    /// Candidate responses from the model.
32    pub candidates: Vec<ContentCandidate>,
33    pub model_version: Option<String>,
34    pub usage_metadata: Option<PartialUsage>,
35}
36
37#[derive(Clone, Debug, Serialize, Deserialize)]
38pub struct StreamingCompletionResponse {
39    pub usage_metadata: PartialUsage,
40}
41
42impl GetTokenUsage for StreamingCompletionResponse {
43    fn token_usage(&self) -> Option<crate::completion::Usage> {
44        let mut usage = crate::completion::Usage::new();
45        usage.total_tokens = self.usage_metadata.total_token_count as u64;
46        usage.output_tokens = self
47            .usage_metadata
48            .candidates_token_count
49            .map(|x| x as u64)
50            .unwrap_or(0);
51        usage.input_tokens = self.usage_metadata.prompt_token_count as u64;
52        Some(usage)
53    }
54}
55
56impl CompletionModel {
57    pub(crate) async fn stream(
58        &self,
59        completion_request: CompletionRequest,
60    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
61    {
62        let request = create_request_body(completion_request)?;
63
64        tracing::debug!(
65            "Sending completion request to Gemini API {}",
66            serde_json::to_string_pretty(&request)?
67        );
68
69        // Build the request with proper headers for SSE
70        let mut event_source = self
71            .client
72            .post_sse(&format!(
73                "/v1beta/models/{}:streamGenerateContent",
74                self.model
75            ))
76            .json(&request)
77            .eventsource()
78            .expect("Cloning request must always succeed");
79
80        let stream = Box::pin(stream! {
81            while let Some(event_result) = event_source.next().await {
82                match event_result {
83                    Ok(Event::Open) => {
84                        tracing::trace!("SSE connection opened");
85                        continue;
86                    }
87                    Ok(Event::Message(message)) => {
88                        // Skip heartbeat messages or empty data
89                        if message.data.trim().is_empty() {
90                            continue;
91                        }
92
93                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
94                            Ok(d) => d,
95                            Err(error) => {
96                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
97                                continue;
98                            }
99                        };
100
101                        // Process the response data
102                        let Some(choice) = data.candidates.first() else {
103                            tracing::debug!("There is no content candidate");
104                            continue;
105                        };
106
107                        match choice.content.parts.first() {
108                            Some(Part {
109                                part: PartKind::Text(text),
110                                thought: Some(true),
111                                ..
112                            }) => {
113                                yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None });
114                            },
115                            Some(Part {
116                                part: PartKind::Text(text),
117                                ..
118                            }) => {
119                                yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
120                            },
121                            Some(Part {
122                                part: PartKind::FunctionCall(function_call),
123                                ..
124                            }) => {
125                                yield Ok(streaming::RawStreamingChoice::ToolCall {
126                                    name: function_call.name.clone(),
127                                    id: function_call.name.clone(),
128                                    arguments: function_call.args.clone(),
129                                    call_id: None
130                                });
131                            },
132                            Some(part) => {
133                                tracing::warn!(?part, "Unsupported response type with streaming");
134                            }
135                            None => tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content"),
136                        }
137
138                        // Check if this is the final response
139                        if choice.finish_reason.is_some() {
140                            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
141                                usage_metadata: data.usage_metadata.unwrap_or_default()
142                            }));
143                            break;
144                        }
145                    }
146                    Err(reqwest_eventsource::Error::StreamEnded) => {
147                        break;
148                    }
149                    Err(error) => {
150                        tracing::error!(?error, "SSE error");
151                        yield Err(CompletionError::ResponseError(error.to_string()));
152                        break;
153                    }
154                }
155            }
156
157            // Ensure event source is closed when stream ends
158            event_source.close();
159        });
160
161        Ok(streaming::StreamingCompletionResponse::stream(stream))
162    }
163}