1use std::{collections::HashMap, io::Write};
2
3use tokio::io::AsyncWrite;
4
5use crate::{
6 content::TryIntoContents,
7 error::{ActionError, Error, ServiceError},
8 genai::{GenerativeModel, ResponseStream as GenResponseStream},
9 proto::{part::Data, Candidate, CitationMetadata, Content, GenerateContentResponse, Part},
10};
11
/// A multi-turn chat with a [`GenerativeModel`], created by [`GenerativeModel::start_chat`].
///
/// Every turn sends the whole [`Session::history`] to the model. The history grows by the
/// user's contents for each turn and by the model's reply content, which is re-tagged with
/// the `"model"` role before being stored.
#[derive(Debug)]
pub struct Session<'m> {
    /// Model that every turn of this chat is sent to.
    model: &'m GenerativeModel<'m>,
    /// Conversation so far, oldest entry first (new turns are appended at the end).
    /// It is public, so callers can read or modify it between turns.
    pub history: Vec<Content>,
}
31
32impl GenerativeModel<'_> {
33 pub fn start_chat(&self) -> Session<'_> {
35 Session {
36 model: self,
37 history: Vec::new(),
38 }
39 }
40}
41
42impl<'m> Session<'m> {
43 pub async fn send_message<T>(&mut self, contents: T) -> Result<GenerateContentResponse, Error>
48 where
49 T: TryIntoContents,
50 {
51 self.history.extend(contents.try_into_contents()?);
52
53 let response = self.model.generate_content(self.history.clone()).await?;
54
55 self.add_best_candidate_to_history(&response.candidates)
56 .ok_or(Error::Service(ServiceError::InvalidResponse(
57 "No valid candidates".into(),
58 )))?;
59
60 Ok(response)
61 }
62
63 pub async fn stream_send_message<'s, T>(
67 &'s mut self,
68 contents: T,
69 ) -> Result<ResponseStream<'s, 'm>, Error>
70 where
71 T: TryIntoContents,
72 {
73 self.history.extend(contents.try_into_contents()?);
74
75 let stream = self
76 .model
77 .stream_generate_content(self.history.clone())
78 .await?;
79
80 Ok(ResponseStream {
81 inner: stream,
82 merged_candidates: Vec::new(),
83 session: self,
84 is_complete: false,
85 })
86 }
87
88 fn add_best_candidate_to_history(&mut self, candidates: &[Candidate]) -> Option<()> {
90 candidates.first().and_then(|candidate| {
91 candidate.content.as_ref().map(|content| {
92 let mut model_content = content.clone();
93 model_content.role = "model".to_owned();
94 self.history.push(model_content);
95 })
96 })
97 }
98}
99
100pub struct ResponseStream<'s, 'm> {
102 session: &'s mut Session<'m>,
103 inner: GenResponseStream,
104 merged_candidates: Vec<Candidate>,
105 is_complete: bool,
106}
107
108impl ResponseStream<'_, '_> {
109 pub async fn write_to<W: Write>(&mut self, dst: &mut W) -> Result<usize, Error> {
114 let mut total = 0;
115
116 while let Some(response) = self
117 .next()
118 .await
119 .map_err(|e| Error::Stream(ActionError::Error(e.into())))?
120 {
121 let bytes = response.try_into_bytes()?;
122 let written = dst
123 .write(&bytes)
124 .map_err(|e| Error::Stream(ActionError::Action(e)))?;
125 total += written;
126 }
127
128 Ok(total)
129 }
130
131 pub async fn write_to_sync<W: AsyncWrite + std::marker::Unpin>(
136 &mut self,
137 dst: &mut W,
138 ) -> Result<usize, Error> {
139 use tokio::io::AsyncWriteExt;
140
141 let mut total = 0;
142
143 while let Some(response) = self
144 .next()
145 .await
146 .map_err(|e| Error::Stream(ActionError::Error(e.into())))?
147 {
148 let bytes = response.try_into_bytes()?;
149 let written = dst
150 .write(&bytes)
151 .await
152 .map_err(|e| Error::Stream(ActionError::Action(e)))?;
153 total += written;
154 }
155
156 Ok(total)
157 }
158
159 pub async fn next(&mut self) -> Result<Option<GenerateContentResponse>, Error> {
161 if self.is_complete {
162 return Ok(None);
163 }
164
165 match self.inner.next().await? {
166 Some(response) => {
167 merge_candidates(&mut self.merged_candidates, &response.candidates);
168 Ok(Some(response))
169 }
170 None => {
171 self.session
172 .add_best_candidate_to_history(&self.merged_candidates);
173 self.is_complete = true;
174 Ok(None)
175 }
176 }
177 }
178}
179
180pub fn merge_candidates(merged: &mut Vec<Candidate>, new_candidates: &[Candidate]) {
182 if merged.is_empty() {
183 merged.extend_from_slice(new_candidates);
184 return;
185 }
186
187 let candidate_map: HashMap<_, _> = new_candidates
188 .iter()
189 .filter_map(|c| c.index.as_ref().map(|i| (i, c)))
190 .collect();
191
192 for candidate in merged.iter_mut() {
193 if let Some(index) = &candidate.index {
194 if let Some(new_candidate) = candidate_map.get(index) {
195 merge_candidate_data(candidate, new_candidate);
196 }
197 }
198 }
199}
200
201pub fn merge_candidate_data(target: &mut Candidate, source: &Candidate) {
203 if let Some(source_content) = &source.content {
205 target.content = match target.content.take() {
206 Some(existing) => Some(merge_content(existing, source_content.clone())),
207 None => Some(source_content.clone()),
208 };
209 }
210
211 target.finish_reason.clone_from(&source.finish_reason);
213 target.safety_ratings.clone_from(&source.safety_ratings);
214
215 if let Some(source_citations) = &source.citation_metadata {
217 target.citation_metadata = match target.citation_metadata.take() {
218 Some(existing) => Some(merge_citations(existing, source_citations)),
219 None => Some(source_citations.clone()),
220 };
221 }
222}
223
224pub fn merge_content(mut existing: Content, update: Content) -> Content {
226 existing.parts = merge_parts(existing.parts, update.parts);
227 existing
228}
229
230pub fn merge_parts(mut existing: Vec<Part>, update: Vec<Part>) -> Vec<Part> {
232 let mut buffer = String::new();
233 let mut merged = Vec::new();
234
235 for part in existing.drain(..) {
237 if let Some(Data::Text(text)) = &part.data {
238 buffer.push_str(text);
239 } else {
240 if !buffer.is_empty() {
241 merged.push(Part::text(&buffer));
242 buffer.clear();
243 }
244 merged.push(part);
245 }
246 }
247
248 for part in update {
250 if let Some(Data::Text(text)) = &part.data {
251 buffer.push_str(text);
252 } else {
253 if !buffer.is_empty() {
254 merged.push(Part::text(&buffer));
255 buffer.clear();
256 }
257 merged.push(part);
258 }
259 }
260
261 if !buffer.is_empty() {
263 merged.push(Part {
264 data: Some(Data::Text(buffer)),
265 });
266 }
267
268 merged
269}
270
271fn merge_citations(mut existing: CitationMetadata, update: &CitationMetadata) -> CitationMetadata {
273 existing
274 .citation_sources
275 .extend(update.citation_sources.iter().cloned());
276 existing
277}
278
#[cfg(test)]
mod tests {
    use super::{merge_candidates, merge_parts};
    use crate::{
        content::IntoParts,
        proto::{Candidate, Content, Part},
    };

    /// Candidate with every field empty or zero; each case overrides only the fields it checks.
    fn blank() -> Candidate {
        Candidate {
            index: None,
            content: None,
            finish_reason: 0,
            safety_ratings: vec![],
            citation_metadata: None,
            token_count: 0,
            grounding_attributions: vec![],
            grounding_metadata: None,
            avg_logprobs: 0.0,
            logprobs_result: None,
        }
    }

    #[test]
    fn _merge_candidates() {
        // First chunk: candidates with indices 2 and 0.
        let mut merged = vec![
            Candidate {
                index: Some(2),
                content: Some(Content::model("r1 i2")),
                finish_reason: 1,
                ..blank()
            },
            Candidate {
                index: Some(0),
                content: Some(Content::model("r1 i0")),
                finish_reason: 2,
                ..blank()
            },
        ];

        // Second chunk: continues index 0 and introduces index 1. Index 1 has no
        // counterpart in `merged`, so it is not added.
        let incoming = vec![
            Candidate {
                index: Some(0),
                content: Some(Content::model(";r2 i0")),
                finish_reason: 3,
                ..blank()
            },
            Candidate {
                index: Some(1),
                content: Some(Content::model(";r2 i1")),
                finish_reason: 4,
                ..blank()
            },
        ];

        let want = vec![
            Candidate {
                index: Some(2),
                content: Some(Content::model("r1 i2")),
                finish_reason: 1,
                ..blank()
            },
            Candidate {
                index: Some(0),
                content: Some(Content::model("r1 i0;r2 i0")),
                finish_reason: 3,
                ..blank()
            },
        ];

        merge_candidates(&mut merged, &incoming);
        assert_eq!(merged, want);

        // Merging into an empty list adopts the incoming candidates unchanged.
        let mut fresh = vec![];
        merge_candidates(&mut fresh, &want);
        assert_eq!(fresh, want);
    }

    #[test]
    fn merge_texts() {
        // (update, want) pairs; each update is merged onto an empty `existing`.
        let cases: Vec<(Vec<Part>, Vec<Part>)> = vec![
            (vec![Part::text("a")], vec![Part::text("a")]),
            (
                vec![Part::text("a"), Part::text("b"), Part::text("c")],
                vec![Part::text("abc")],
            ),
            (
                vec![
                    Part::blob("b1", vec![]),
                    Part::text("a"),
                    Part::text("b"),
                    Part::blob("b2", vec![]),
                    Part::text("c"),
                ],
                vec![
                    Part::blob("b1", vec![]),
                    Part::text("ab"),
                    Part::blob("b2", vec![]),
                    Part::text("c"),
                ],
            ),
            (
                vec![
                    Part::text("a"),
                    Part::text("b"),
                    Part::blob("b1", vec![]),
                    Part::text("c"),
                    Part::text("d"),
                    Part::blob("b2", vec![]),
                ],
                vec![
                    Part::text("ab"),
                    Part::blob("b1", vec![]),
                    Part::text("cd"),
                    Part::blob("b2", vec![]),
                ],
            ),
        ];

        for (update, want) in cases {
            assert_eq!(merge_parts(vec![], update), want);
        }
    }
}