rig/providers/gemini/
streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::Deserialize;
4
5use super::completion::{CompletionModel, create_request_body, gemini_api_types::ContentCandidate};
6use crate::{
7    completion::{CompletionError, CompletionRequest},
8    streaming::{self},
9};
10
11#[derive(Debug, Deserialize, Default, Clone)]
12#[serde(rename_all = "camelCase")]
13pub struct PartialUsage {
14    pub total_token_count: i32,
15}
16
17#[derive(Debug, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct StreamGenerateContentResponse {
20    /// Candidate responses from the model.
21    pub candidates: Vec<ContentCandidate>,
22    pub model_version: Option<String>,
23    pub usage_metadata: Option<PartialUsage>,
24}
25
26#[derive(Clone, Debug)]
27pub struct StreamingCompletionResponse {
28    pub usage_metadata: PartialUsage,
29}
30
31impl CompletionModel {
32    pub(crate) async fn stream(
33        &self,
34        completion_request: CompletionRequest,
35    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
36    {
37        let request = create_request_body(completion_request)?;
38
39        let response = self
40            .client
41            .post_sse(&format!(
42                "/v1beta/models/{}:streamGenerateContent",
43                self.model
44            ))
45            .json(&request)
46            .send()
47            .await?;
48
49        if !response.status().is_success() {
50            return Err(CompletionError::ProviderError(format!(
51                "{}: {}",
52                response.status(),
53                response.text().await?
54            )));
55        }
56
57        let stream = Box::pin(stream! {
58            let mut stream = response.bytes_stream();
59
60            while let Some(chunk_result) = stream.next().await {
61                let chunk = match chunk_result {
62                    Ok(c) => c,
63                    Err(e) => {
64                        yield Err(CompletionError::from(e));
65                        break;
66                    }
67                };
68
69                let text = match String::from_utf8(chunk.to_vec()) {
70                    Ok(t) => t,
71                    Err(e) => {
72                        yield Err(CompletionError::ResponseError(e.to_string()));
73                        break;
74                    }
75                };
76
77
78                for line in text.lines() {
79                    let Some(line) = line.strip_prefix("data: ") else { continue; };
80
81                    let Ok(data) = serde_json::from_str::<StreamGenerateContentResponse>(line) else {
82                        continue;
83                    };
84
85                    let choice = data.candidates.first().expect("Should have at least one choice");
86
87                    match choice.content.parts.first() {
88                        super::completion::gemini_api_types::Part::Text(text)
89                            => yield Ok(streaming::RawStreamingChoice::Message(text)),
90                        super::completion::gemini_api_types::Part::FunctionCall(function_call)
91                            => yield Ok(streaming::RawStreamingChoice::ToolCall {
92                                    name: function_call.name,
93                                    id: "".to_string(),
94                                    arguments: function_call.args,
95                                    call_id: None
96                                }),
97                        _ => panic!("Unsupported response type with streaming.")
98                    };
99
100                    if choice.finish_reason.is_some() {
101                        yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
102                            usage_metadata: PartialUsage {
103                                total_token_count: data.usage_metadata.unwrap().total_token_count,
104                            }
105                        }))
106                    }
107                }
108            }
109        });
110
111        Ok(streaming::StreamingCompletionResponse::stream(stream))
112    }
113}