rig/providers/gemini/streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::Deserialize;
4
5use crate::{
6    completion::{CompletionError, CompletionRequest},
7    streaming::{self, StreamingCompletionModel, StreamingResult},
8};
9
10use super::completion::{create_request_body, gemini_api_types::ContentCandidate, CompletionModel};
11
12#[derive(Debug, Deserialize)]
13#[serde(rename_all = "camelCase")]
14pub struct StreamGenerateContentResponse {
15    /// Candidate responses from the model.
16    pub candidates: Vec<ContentCandidate>,
17    pub model_version: Option<String>,
18}
19
20impl StreamingCompletionModel for CompletionModel {
21    async fn stream(
22        &self,
23        completion_request: CompletionRequest,
24    ) -> Result<StreamingResult, CompletionError> {
25        let request = create_request_body(completion_request)?;
26
27        let response = self
28            .client
29            .post_sse(&format!(
30                "/v1beta/models/{}:streamGenerateContent",
31                self.model
32            ))
33            .json(&request)
34            .send()
35            .await?;
36
37        if !response.status().is_success() {
38            return Err(CompletionError::ProviderError(format!(
39                "{}: {}",
40                response.status(),
41                response.text().await?
42            )));
43        }
44
45        Ok(Box::pin(stream! {
46            let mut stream = response.bytes_stream();
47
48            while let Some(chunk_result) = stream.next().await {
49                let chunk = match chunk_result {
50                    Ok(c) => c,
51                    Err(e) => {
52                        yield Err(CompletionError::from(e));
53                        break;
54                    }
55                };
56
57                let text = match String::from_utf8(chunk.to_vec()) {
58                    Ok(t) => t,
59                    Err(e) => {
60                        yield Err(CompletionError::ResponseError(e.to_string()));
61                        break;
62                    }
63                };
64
65
66                for line in text.lines() {
67                    let Some(line) = line.strip_prefix("data: ") else { continue; };
68
69                    let Ok(data) = serde_json::from_str::<StreamGenerateContentResponse>(line) else {
70                        continue;
71                    };
72
73                    let choice = data.candidates.first().expect("Should have at least one choice");
74
75                    match choice.content.parts.first() {
76                        super::completion::gemini_api_types::Part::Text(text)
77                            => yield Ok(streaming::StreamingChoice::Message(text)),
78                        super::completion::gemini_api_types::Part::FunctionCall(function_call)
79                            => yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
80                        _ => panic!("Unsupported response type with streaming.")
81                    };
82                }
83            }
84        }))
85    }
86}