rig/providers/gemini/
streaming.rs1use async_stream::stream;
2use futures::StreamExt;
3use reqwest_eventsource::{Event, RequestBuilderExt};
4use serde::{Deserialize, Serialize};
5
6use super::completion::{
7 CompletionModel, create_request_body,
8 gemini_api_types::{ContentCandidate, Part, PartKind},
9};
10use crate::{
11 completion::{CompletionError, CompletionRequest, GetTokenUsage},
12 streaming::{self},
13};
14
15#[derive(Debug, Deserialize, Serialize, Default, Clone)]
16#[serde(rename_all = "camelCase")]
17pub struct PartialUsage {
18 pub total_token_count: i32,
19}
20
21#[derive(Debug, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct StreamGenerateContentResponse {
24 pub candidates: Vec<ContentCandidate>,
26 pub model_version: Option<String>,
27 pub usage_metadata: Option<PartialUsage>,
28}
29
30#[derive(Clone, Debug, Serialize, Deserialize)]
31pub struct StreamingCompletionResponse {
32 pub usage_metadata: PartialUsage,
33}
34
35impl GetTokenUsage for StreamingCompletionResponse {
36 fn token_usage(&self) -> Option<crate::completion::Usage> {
37 let mut usage = crate::completion::Usage::new();
38 usage.total_tokens = self.usage_metadata.total_token_count as u64;
39 Some(usage)
40 }
41}
42
43impl CompletionModel {
44 pub(crate) async fn stream(
45 &self,
46 completion_request: CompletionRequest,
47 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
48 {
49 let request = create_request_body(completion_request)?;
50
51 tracing::debug!(
52 "Sending completion request to Gemini API {}",
53 serde_json::to_string_pretty(&request)?
54 );
55
56 let mut event_source = self
58 .client
59 .post_sse(&format!(
60 "/v1beta/models/{}:streamGenerateContent",
61 self.model
62 ))
63 .json(&request)
64 .eventsource()
65 .expect("Cloning request must always succeed");
66
67 let stream = Box::pin(stream! {
68 while let Some(event_result) = event_source.next().await {
69 match event_result {
70 Ok(Event::Open) => {
71 tracing::trace!("SSE connection opened");
72 continue;
73 }
74 Ok(Event::Message(message)) => {
75 if message.data.trim().is_empty() {
77 continue;
78 }
79
80 let data = match serde_json::from_str::<StreamGenerateContentResponse>(&message.data) {
81 Ok(d) => d,
82 Err(error) => {
83 tracing::error!(?error, message = message.data, "Failed to parse SSE message");
84 continue;
85 }
86 };
87
88 let Some(choice) = data.candidates.first() else {
90 tracing::debug!("There is no content candidate");
91 continue;
92 };
93
94 match choice.content.parts.first() {
95 Some(Part {
96 part: PartKind::Text(text),
97 thought: Some(true),
98 ..
99 }) => {
100 yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None });
101 },
102 Some(Part {
103 part: PartKind::Text(text),
104 ..
105 }) => {
106 yield Ok(streaming::RawStreamingChoice::Message(text.clone()));
107 },
108 Some(Part {
109 part: PartKind::FunctionCall(function_call),
110 ..
111 }) => {
112 yield Ok(streaming::RawStreamingChoice::ToolCall {
113 name: function_call.name.clone(),
114 id: function_call.name.clone(),
115 arguments: function_call.args.clone(),
116 call_id: None
117 });
118 },
119 Some(part) => {
120 tracing::warn!(?part, "Unsupported response type with streaming");
121 }
122 None => tracing::trace!(reason = ?choice.finish_reason, "There is no part in the streaming content"),
123 }
124
125 if choice.finish_reason.is_some() {
127 let usage = data.usage_metadata
128 .map(|u| u.total_token_count)
129 .unwrap_or(0);
130
131 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
132 usage_metadata: PartialUsage {
133 total_token_count: usage,
134 }
135 }));
136 break;
137 }
138 }
139 Err(reqwest_eventsource::Error::StreamEnded) => {
140 break;
141 }
142 Err(error) => {
143 tracing::error!(?error, "SSE error");
144 yield Err(CompletionError::ResponseError(error.to_string()));
145 break;
146 }
147 }
148 }
149
150 event_source.close();
152 });
153
154 Ok(streaming::StreamingCompletionResponse::stream(stream))
155 }
156}