// rig/providers/gemini/streaming.rs
use async_stream::stream;
use futures::StreamExt;
use serde::Deserialize;

use crate::{
    completion::{CompletionError, CompletionRequest},
    streaming::{self, StreamingCompletionModel, StreamingResult},
};

use super::completion::{
    create_request_body,
    gemini_api_types::{ContentCandidate, Part},
    CompletionModel,
};
11
12#[derive(Debug, Deserialize)]
13#[serde(rename_all = "camelCase")]
14pub struct StreamGenerateContentResponse {
15 pub candidates: Vec<ContentCandidate>,
17 pub model_version: Option<String>,
18}
19
20impl StreamingCompletionModel for CompletionModel {
21 async fn stream(
22 &self,
23 completion_request: CompletionRequest,
24 ) -> Result<StreamingResult, CompletionError> {
25 let request = create_request_body(completion_request)?;
26
27 let response = self
28 .client
29 .post_sse(&format!(
30 "/v1beta/models/{}:streamGenerateContent",
31 self.model
32 ))
33 .json(&request)
34 .send()
35 .await?;
36
37 if !response.status().is_success() {
38 return Err(CompletionError::ProviderError(format!(
39 "{}: {}",
40 response.status(),
41 response.text().await?
42 )));
43 }
44
45 Ok(Box::pin(stream! {
46 let mut stream = response.bytes_stream();
47
48 while let Some(chunk_result) = stream.next().await {
49 let chunk = match chunk_result {
50 Ok(c) => c,
51 Err(e) => {
52 yield Err(CompletionError::from(e));
53 break;
54 }
55 };
56
57 let text = match String::from_utf8(chunk.to_vec()) {
58 Ok(t) => t,
59 Err(e) => {
60 yield Err(CompletionError::ResponseError(e.to_string()));
61 break;
62 }
63 };
64
65
66 for line in text.lines() {
67 let Some(line) = line.strip_prefix("data: ") else { continue; };
68
69 let Ok(data) = serde_json::from_str::<StreamGenerateContentResponse>(line) else {
70 continue;
71 };
72
73 let choice = data.candidates.first().expect("Should have at least one choice");
74
75 match choice.content.parts.first() {
76 super::completion::gemini_api_types::Part::Text(text)
77 => yield Ok(streaming::StreamingChoice::Message(text)),
78 super::completion::gemini_api_types::Part::FunctionCall(function_call)
79 => yield Ok(streaming::StreamingChoice::ToolCall(function_call.name, "".to_string(), function_call.args)),
80 _ => panic!("Unsupported response type with streaming.")
81 };
82 }
83 }
84 }))
85 }
86}