rig/providers/gemini/
streaming.rs1use async_stream::stream;
2use futures::StreamExt;
3use serde::Deserialize;
4
5use super::completion::{CompletionModel, create_request_body, gemini_api_types::ContentCandidate};
6use crate::{
7 completion::{CompletionError, CompletionRequest},
8 streaming::{self},
9};
10
11#[derive(Debug, Deserialize, Default, Clone)]
12#[serde(rename_all = "camelCase")]
13pub struct PartialUsage {
14 pub total_token_count: i32,
15}
16
17#[derive(Debug, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct StreamGenerateContentResponse {
20 pub candidates: Vec<ContentCandidate>,
22 pub model_version: Option<String>,
23 pub usage_metadata: Option<PartialUsage>,
24}
25
26#[derive(Clone, Debug)]
27pub struct StreamingCompletionResponse {
28 pub usage_metadata: PartialUsage,
29}
30
31impl CompletionModel {
32 pub(crate) async fn stream(
33 &self,
34 completion_request: CompletionRequest,
35 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
36 {
37 let request = create_request_body(completion_request)?;
38
39 let response = self
40 .client
41 .post_sse(&format!(
42 "/v1beta/models/{}:streamGenerateContent",
43 self.model
44 ))
45 .json(&request)
46 .send()
47 .await?;
48
49 if !response.status().is_success() {
50 return Err(CompletionError::ProviderError(format!(
51 "{}: {}",
52 response.status(),
53 response.text().await?
54 )));
55 }
56
57 let stream = Box::pin(stream! {
58 let mut stream = response.bytes_stream();
59
60 while let Some(chunk_result) = stream.next().await {
61 let chunk = match chunk_result {
62 Ok(c) => c,
63 Err(e) => {
64 yield Err(CompletionError::from(e));
65 break;
66 }
67 };
68
69 let text = match String::from_utf8(chunk.to_vec()) {
70 Ok(t) => t,
71 Err(e) => {
72 yield Err(CompletionError::ResponseError(e.to_string()));
73 break;
74 }
75 };
76
77
78 for line in text.lines() {
79 let Some(line) = line.strip_prefix("data: ") else { continue; };
80
81 let Ok(data) = serde_json::from_str::<StreamGenerateContentResponse>(line) else {
82 continue;
83 };
84
85 let choice = data.candidates.first().expect("Should have at least one choice");
86
87 match choice.content.parts.first() {
88 super::completion::gemini_api_types::Part::Text(text)
89 => yield Ok(streaming::RawStreamingChoice::Message(text)),
90 super::completion::gemini_api_types::Part::FunctionCall(function_call)
91 => yield Ok(streaming::RawStreamingChoice::ToolCall {
92 name: function_call.name,
93 id: "".to_string(),
94 arguments: function_call.args,
95 call_id: None
96 }),
97 _ => panic!("Unsupported response type with streaming.")
98 };
99
100 if choice.finish_reason.is_some() {
101 yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
102 usage_metadata: PartialUsage {
103 total_token_count: data.usage_metadata.unwrap().total_token_count,
104 }
105 }))
106 }
107 }
108 }
109 });
110
111 Ok(streaming::StreamingCompletionResponse::stream(stream))
112 }
113}