rig/providers/gemini/streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use reqwest_eventsource::{Event, RequestBuilderExt};
4use serde::{Deserialize, Serialize};
5
6use super::completion::{
7    CompletionModel, create_request_body,
8    gemini_api_types::{ContentCandidate, Part, PartKind},
9};
10use crate::{
11    completion::{CompletionError, CompletionRequest, GetTokenUsage},
12    streaming::{self},
13};
14
15#[derive(Debug, Deserialize, Serialize, Default, Clone)]
16#[serde(rename_all = "camelCase")]
17pub struct PartialUsage {
18    pub total_token_count: i32,
19}
20
21#[derive(Debug, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct StreamGenerateContentResponse {
24    /// Candidate responses from the model.
25    pub candidates: Vec<ContentCandidate>,
26    pub model_version: Option<String>,
27    pub usage_metadata: Option<PartialUsage>,
28}
29
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct StreamingCompletionResponse {
32    pub usage_metadata: PartialUsage,
33}
34
35impl GetTokenUsage for StreamingCompletionResponse {
36    fn token_usage(&self) -> Option<crate::completion::Usage> {
37        let mut usage = crate::completion::Usage::new();
38        usage.total_tokens = self.usage_metadata.total_token_count as u64;
39        Some(usage)
40    }
41}
42
43impl CompletionModel {
44    pub(crate) async fn stream(
45        &self,
46        completion_request: CompletionRequest,
47    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
48    {
49        let request = create_request_body(completion_request)?;
50
51        tracing::debug!(
52            "Sending completion request to Gemini API {}",
53            serde_json::to_string_pretty(&request)?
54        );
55
56        // Build the request with proper headers for SSE
57        let mut event_source = self
58            .client
59            .post_sse(&format!(
60                "/v1beta/models/{}:streamGenerateContent",
61                self.model
62            ))
63            .json(&request)
64            .eventsource()
65            .expect("Cloning request must always succeed");
66
67        let stream = Box::pin(stream! {
68            while let Some(event_result) = event_source.next().await {
69                match event_result {
70                    Ok(Event::Open) => {
71                        tracing::trace!("SSE connection opened");
72                        continue;
73                    }
74                    Ok(Event::Message(message)) => {
75                        // Skip heartbeat messages or empty data
76                        if message.data.trim().is_empty() {
77                            continue;
78                        }
79
80                        let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
81                            Ok(d) => d,
82                            Err(error) => {
83                                tracing::error!(?error, message = message.data, "Failed to parse SSE message");
84                                continue;
85                            }
86                        };
87
88                        // Process the response data
89                        let Some(choice) = data.candidates.first() else {
90                            tracing::debug!("There is no content candidate");
91                            continue;
92                        };
93
94                        match choice.content.parts.first() {
95                            Some(Part {
96                                part: PartKind::Text(text),
97                                thought: Some(true),
98                                ..
99                            }) => {
100                                yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None });
101                            },
102                            Some(Part {
103                                part: PartKind::Text(text),
104                                ..
105                            }) => {
106                                yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
107                            },
108                            Some(Part {
109                                part: PartKind::FunctionCall(function_call),
110                                ..
111                            }) => {
112                                yield Ok(streaming::RawStreamingChoice::ToolCall {
113                                    name: function_call.name.clone(),
114                                    id: function_call.name.clone(),
115                                    arguments: function_call.args.clone(),
116                                    call_id: None
117                                });
118                            },
119                            Some(part) => {
120                                tracing::warn!(?part, "Unsupported response type with streaming");
121                            }
122                            None => tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content"),
123                        }
124
125                        // Check if this is the final response
126                        if choice.finish_reason.is_some() {
127                            let usage = data.usage_metadata
128                                            .map(|u| u.total_token_count)
129                                            .unwrap_or(0);
130
131                            yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
132                                usage_metadata: PartialUsage {
133                                    total_token_count: usage,
134                                }
135                            }));
136                            break;
137                        }
138                    }
139                    Err(reqwest_eventsource::Error::StreamEnded) => {
140                        break;
141                    }
142                    Err(error) => {
143                        tracing::error!(?error, "SSE error");
144                        yield Err(CompletionError::ResponseError(error.to_string()));
145                        break;
146                    }
147                }
148            }
149
150            // Ensure event source is closed when stream ends
151            event_source.close();
152        });
153
154        Ok(streaming::StreamingCompletionResponse::stream(stream))
155    }
156}