Skip to main content

gproxy_protocol/transform/gemini/
stream_to_nonstream.rs

1use std::collections::BTreeMap;
2
3use crate::gemini::generate_content::request::GeminiGenerateContentRequest;
4use crate::gemini::generate_content::response::ResponseBody as GeminiGenerateContentResponseBody;
5use crate::gemini::generate_content::types::GeminiCandidate;
6use crate::gemini::stream_generate_content::request::GeminiStreamGenerateContentRequest;
7use crate::transform::utils::TransformError;
8
9fn merge_candidate(target: &mut GeminiCandidate, incoming: GeminiCandidate, index: u32) {
10    target.index = Some(index);
11
12    if let Some(content) = incoming.content {
13        if let Some(target_content) = target.content.as_mut() {
14            if target_content.role.is_none() {
15                target_content.role = content.role;
16            }
17            target_content.parts.extend(content.parts);
18        } else {
19            target.content = Some(content);
20        }
21    }
22
23    if incoming.finish_reason.is_some() {
24        target.finish_reason = incoming.finish_reason;
25    }
26    if incoming.safety_ratings.is_some() {
27        target.safety_ratings = incoming.safety_ratings;
28    }
29    if incoming.citation_metadata.is_some() {
30        target.citation_metadata = incoming.citation_metadata;
31    }
32    if incoming.token_count.is_some() {
33        target.token_count = incoming.token_count;
34    }
35    if let Some(grounding_attributions) = incoming.grounding_attributions {
36        if let Some(existing) = target.grounding_attributions.as_mut() {
37            existing.extend(grounding_attributions);
38        } else {
39            target.grounding_attributions = Some(grounding_attributions);
40        }
41    }
42    if incoming.grounding_metadata.is_some() {
43        target.grounding_metadata = incoming.grounding_metadata;
44    }
45    if incoming.avg_logprobs.is_some() {
46        target.avg_logprobs = incoming.avg_logprobs;
47    }
48    if incoming.logprobs_result.is_some() {
49        target.logprobs_result = incoming.logprobs_result;
50    }
51    if incoming.url_context_metadata.is_some() {
52        target.url_context_metadata = incoming.url_context_metadata;
53    }
54    if incoming.finish_message.is_some() {
55        target.finish_message = incoming.finish_message;
56    }
57}
58
59pub fn merge_chunk(
60    merged: &mut GeminiGenerateContentResponseBody,
61    candidate_map: &mut BTreeMap<u32, GeminiCandidate>,
62    chunk: GeminiGenerateContentResponseBody,
63) {
64    if let Some(candidates) = chunk.candidates {
65        for (pos, candidate) in candidates.into_iter().enumerate() {
66            let index = candidate.index.unwrap_or(pos as u32);
67            let entry = candidate_map
68                .entry(index)
69                .or_insert_with(|| GeminiCandidate {
70                    index: Some(index),
71                    ..GeminiCandidate::default()
72                });
73            merge_candidate(entry, candidate, index);
74        }
75    }
76
77    if chunk.prompt_feedback.is_some() {
78        merged.prompt_feedback = chunk.prompt_feedback;
79    }
80    if chunk.usage_metadata.is_some() {
81        merged.usage_metadata = chunk.usage_metadata;
82    }
83    if chunk.model_version.is_some() {
84        merged.model_version = chunk.model_version;
85    }
86    if chunk.response_id.is_some() {
87        merged.response_id = chunk.response_id;
88    }
89    if chunk.model_status.is_some() {
90        merged.model_status = chunk.model_status;
91    }
92}
93
94pub fn finalize_body(
95    mut merged: GeminiGenerateContentResponseBody,
96    candidate_map: BTreeMap<u32, GeminiCandidate>,
97) -> GeminiGenerateContentResponseBody {
98    if candidate_map.is_empty() {
99        merged.candidates = None;
100    } else {
101        merged.candidates = Some(candidate_map.into_values().collect());
102    }
103    merged
104}
105
106impl TryFrom<&GeminiStreamGenerateContentRequest> for GeminiGenerateContentRequest {
107    type Error = TransformError;
108
109    fn try_from(value: &GeminiStreamGenerateContentRequest) -> Result<Self, TransformError> {
110        Ok(GeminiGenerateContentRequest {
111            method: value.method,
112            path: crate::gemini::generate_content::request::PathParameters {
113                model: value.path.model.clone(),
114            },
115            query: crate::gemini::generate_content::request::QueryParameters::default(),
116            headers: crate::gemini::generate_content::request::RequestHeaders::default(),
117            body: value.body.clone(),
118        })
119    }
120}
121
122impl TryFrom<GeminiStreamGenerateContentRequest> for GeminiGenerateContentRequest {
123    type Error = TransformError;
124
125    fn try_from(value: GeminiStreamGenerateContentRequest) -> Result<Self, TransformError> {
126        GeminiGenerateContentRequest::try_from(&value)
127    }
128}