// google_ai_rs/chat.rs

1use std::{collections::HashMap, io::Write};
2
3use tokio::io::AsyncWrite;
4
5use crate::{
6    content::TryIntoContents,
7    error::{ActionError, Error, ServiceError},
8    genai::{GenerativeModel, ResponseStream as GenResponseStream},
9    proto::{part::Data, Candidate, CitationMetadata, Content, GenerateContentResponse, Part},
10};
11
12/// Interactive chat session maintaining conversation history
13///
14/// # Example
15/// ```
16/// # use google_ai_rs::{Client, GenerativeModel};
17/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
18/// # let auth = "YOUR-API-KEY";
19/// let client = Client::new(auth).await?;
20/// let model = client.generative_model("gemini-1.5-pro");
21/// let mut chat = model.start_chat();
22/// let response = chat.send_message("Hello!").await?;
23/// # Ok(())
24/// # }
25/// ```
26#[derive(Debug)]
27pub struct Session<'m> {
28    model: &'m GenerativeModel<'m>,
29    pub history: Vec<Content>,
30}
31
32impl GenerativeModel<'_> {
33    /// Starts a new chat session with empty history
34    pub fn start_chat(&self) -> Session<'_> {
35        Session {
36            model: self,
37            history: Vec::new(),
38        }
39    }
40}
41
42impl<'m> Session<'m> {
43    /// Sends a message and appends response to history
44    ///
45    /// # Errors
46    /// Returns [`Error::Service`] if no valid candidates in response
47    pub async fn send_message<T>(&mut self, contents: T) -> Result<GenerateContentResponse, Error>
48    where
49        T: TryIntoContents,
50    {
51        self.history.extend(contents.try_into_contents()?);
52
53        let response = self.model.generate_content(self.history.clone()).await?;
54
55        self.add_best_candidate_to_history(&response.candidates)
56            .ok_or(Error::Service(ServiceError::InvalidResponse(
57                "No valid candidates".into(),
58            )))?;
59
60        Ok(response)
61    }
62
63    /// Starts a streaming response while maintaining session state
64    ///
65    /// `NOTE`: response is only added to history if whole message is consumed
66    pub async fn stream_send_message<'s, T>(
67        &'s mut self,
68        contents: T,
69    ) -> Result<ResponseStream<'s, 'm>, Error>
70    where
71        T: TryIntoContents,
72    {
73        self.history.extend(contents.try_into_contents()?);
74
75        let stream = self
76            .model
77            .stream_generate_content(self.history.clone())
78            .await?;
79
80        Ok(ResponseStream {
81            inner: stream,
82            merged_candidates: Vec::new(),
83            session: self,
84            is_complete: false,
85        })
86    }
87
88    /// Adds the most appropriate candidate to chat history
89    fn add_best_candidate_to_history(&mut self, candidates: &[Candidate]) -> Option<()> {
90        candidates.first().and_then(|candidate| {
91            candidate.content.as_ref().map(|content| {
92                let mut model_content = content.clone();
93                model_content.role = "model".to_owned();
94                self.history.push(model_content);
95            })
96        })
97    }
98}
99
100/// Streaming response handler that maintains session continuity
101pub struct ResponseStream<'s, 'm> {
102    session: &'s mut Session<'m>,
103    inner: GenResponseStream,
104    merged_candidates: Vec<Candidate>,
105    is_complete: bool,
106}
107
108impl ResponseStream<'_, '_> {
109    /// Streams content chunks to any `Write` implementer
110    ///
111    /// # Returns
112    /// Total bytes written
113    pub async fn write_to<W: Write>(&mut self, dst: &mut W) -> Result<usize, Error> {
114        let mut total = 0;
115
116        while let Some(response) = self
117            .next()
118            .await
119            .map_err(|e| Error::Stream(ActionError::Error(e.into())))?
120        {
121            let bytes = response.try_into_bytes()?;
122            let written = dst
123                .write(&bytes)
124                .map_err(|e| Error::Stream(ActionError::Action(e)))?;
125            total += written;
126        }
127
128        Ok(total)
129    }
130
131    /// Streams content chunks to any `AsyncWrite` implementer
132    ///
133    /// # Returns
134    /// Total bytes written
135    pub async fn write_to_sync<W: AsyncWrite + std::marker::Unpin>(
136        &mut self,
137        dst: &mut W,
138    ) -> Result<usize, Error> {
139        use tokio::io::AsyncWriteExt;
140
141        let mut total = 0;
142
143        while let Some(response) = self
144            .next()
145            .await
146            .map_err(|e| Error::Stream(ActionError::Error(e.into())))?
147        {
148            let bytes = response.try_into_bytes()?;
149            let written = dst
150                .write(&bytes)
151                .await
152                .map_err(|e| Error::Stream(ActionError::Action(e)))?;
153            total += written;
154        }
155
156        Ok(total)
157    }
158
159    /// Retrieves next chunk of streaming response
160    pub async fn next(&mut self) -> Result<Option<GenerateContentResponse>, Error> {
161        if self.is_complete {
162            return Ok(None);
163        }
164
165        match self.inner.next().await? {
166            Some(response) => {
167                merge_candidates(&mut self.merged_candidates, &response.candidates);
168                Ok(Some(response))
169            }
170            None => {
171                self.session
172                    .add_best_candidate_to_history(&self.merged_candidates);
173                self.is_complete = true;
174                Ok(None)
175            }
176        }
177    }
178}
179
180/// Merges candidate lists from multiple response chunks
181pub fn merge_candidates(merged: &mut Vec<Candidate>, new_candidates: &[Candidate]) {
182    if merged.is_empty() {
183        merged.extend_from_slice(new_candidates);
184        return;
185    }
186
187    let candidate_map: HashMap<_, _> = new_candidates
188        .iter()
189        .filter_map(|c| c.index.as_ref().map(|i| (i, c)))
190        .collect();
191
192    for candidate in merged.iter_mut() {
193        if let Some(index) = &candidate.index {
194            if let Some(new_candidate) = candidate_map.get(index) {
195                merge_candidate_data(candidate, new_candidate);
196            }
197        }
198    }
199}
200
201/// Merges candidate content and metadata
202pub fn merge_candidate_data(target: &mut Candidate, source: &Candidate) {
203    // Merge content parts
204    if let Some(source_content) = &source.content {
205        target.content = match target.content.take() {
206            Some(existing) => Some(merge_content(existing, source_content.clone())),
207            None => Some(source_content.clone()),
208        };
209    }
210
211    // Update metadata
212    target.finish_reason.clone_from(&source.finish_reason);
213    target.safety_ratings.clone_from(&source.safety_ratings);
214
215    // Merge citations
216    if let Some(source_citations) = &source.citation_metadata {
217        target.citation_metadata = match target.citation_metadata.take() {
218            Some(existing) => Some(merge_citations(existing, source_citations)),
219            None => Some(source_citations.clone()),
220        };
221    }
222}
223
224/// Merges content parts while combining consecutive text elements
225pub fn merge_content(mut existing: Content, update: Content) -> Content {
226    existing.parts = merge_parts(existing.parts, update.parts);
227    existing
228}
229
230/// combines parts while merging adjacent text blocks
231pub fn merge_parts(mut existing: Vec<Part>, update: Vec<Part>) -> Vec<Part> {
232    let mut buffer = String::new();
233    let mut merged = Vec::new();
234
235    // Process existing parts
236    for part in existing.drain(..) {
237        if let Some(Data::Text(text)) = &part.data {
238            buffer.push_str(text);
239        } else {
240            if !buffer.is_empty() {
241                merged.push(Part::text(&buffer));
242                buffer.clear();
243            }
244            merged.push(part);
245        }
246    }
247
248    // Process new parts
249    for part in update {
250        if let Some(Data::Text(text)) = &part.data {
251            buffer.push_str(text);
252        } else {
253            if !buffer.is_empty() {
254                merged.push(Part::text(&buffer));
255                buffer.clear();
256            }
257            merged.push(part);
258        }
259    }
260
261    // Add remaining text
262    if !buffer.is_empty() {
263        merged.push(Part {
264            data: Some(Data::Text(buffer)),
265        });
266    }
267
268    merged
269}
270
271/// Combines citation metadata from multiple responses
272fn merge_citations(mut existing: CitationMetadata, update: &CitationMetadata) -> CitationMetadata {
273    existing
274        .citation_sources
275        .extend(update.citation_sources.iter().cloned());
276    existing
277}
278
#[cfg(test)]
mod tests {
    use super::{merge_candidates, merge_parts};
    use crate::{
        content::IntoParts,
        proto::{Candidate, Content, Part},
    };

    /// A candidate with every field empty/zeroed; tests override only what they check.
    fn blank_candidate() -> Candidate {
        Candidate {
            index: None,
            content: None,
            finish_reason: 0,
            safety_ratings: vec![],
            citation_metadata: None,
            token_count: 0,
            grounding_attributions: vec![],
            grounding_metadata: None,
            avg_logprobs: 0.0,
            logprobs_result: None,
        }
    }

    #[test]
    fn _merge_candidates() {
        let mut merged = vec![
            Candidate {
                index: Some(2),
                content: Some(Content::model("r1 i2")),
                finish_reason: 1,
                ..blank_candidate()
            },
            Candidate {
                index: Some(0),
                content: Some(Content::model("r1 i0")),
                finish_reason: 2,
                ..blank_candidate()
            },
        ];

        let update = vec![
            Candidate {
                index: Some(0),
                content: Some(Content::model(";r2 i0")),
                finish_reason: 3,
                ..blank_candidate()
            },
            // Index 1 has no counterpart in `merged` and is therefore ignored.
            Candidate {
                index: Some(1),
                content: Some(Content::model(";r2 i1")),
                finish_reason: 4,
                ..blank_candidate()
            },
        ];

        let want = vec![
            Candidate {
                index: Some(2),
                content: Some(Content::model("r1 i2")),
                finish_reason: 1,
                ..blank_candidate()
            },
            Candidate {
                index: Some(0),
                content: Some(Content::model("r1 i0;r2 i0")),
                finish_reason: 3,
                ..blank_candidate()
            },
        ];

        merge_candidates(&mut merged, &update);
        assert_eq!(merged, want);

        // Merging into an empty list adopts the new candidates unchanged.
        let mut fresh = vec![];
        merge_candidates(&mut fresh, &want);
        assert_eq!(fresh, want);
    }

    #[test]
    fn merge_texts() {
        struct Case {
            existing: Vec<Part>,
            update: Vec<Part>,
            want: Vec<Part>,
        }

        let cases = vec![
            Case {
                existing: vec![],
                update: vec![Part::text("a")],
                want: vec![Part::text("a")],
            },
            Case {
                existing: vec![],
                update: vec![Part::text("a"), Part::text("b"), Part::text("c")],
                want: vec![Part::text("abc")],
            },
            Case {
                existing: vec![],
                update: vec![
                    Part::blob("b1", vec![]),
                    Part::text("a"),
                    Part::text("b"),
                    Part::blob("b2", vec![]),
                    Part::text("c"),
                ],
                want: vec![
                    Part::blob("b1", vec![]),
                    Part::text("ab"),
                    Part::blob("b2", vec![]),
                    Part::text("c"),
                ],
            },
            Case {
                existing: vec![],
                update: vec![
                    Part::text("a"),
                    Part::text("b"),
                    Part::blob("b1", vec![]),
                    Part::text("c"),
                    Part::text("d"),
                    Part::blob("b2", vec![]),
                ],
                want: vec![
                    Part::text("ab"),
                    Part::blob("b1", vec![]),
                    Part::text("cd"),
                    Part::blob("b2", vec![]),
                ],
            },
            // Streaming case: text at the end of `existing` continues into `update`.
            Case {
                existing: vec![Part::text("Hel")],
                update: vec![Part::text("lo")],
                want: vec![Part::text("Hello")],
            },
            Case {
                existing: vec![Part::text("a"), Part::blob("b1", vec![])],
                update: vec![Part::text("b"), Part::text("c")],
                want: vec![
                    Part::text("a"),
                    Part::blob("b1", vec![]),
                    Part::text("bc"),
                ],
            },
            // Empty text parts are dropped.
            Case {
                existing: vec![Part::text("")],
                update: vec![Part::blob("b1", vec![]), Part::text("")],
                want: vec![Part::blob("b1", vec![])],
            },
        ];

        for case in cases {
            assert_eq!(merge_parts(case.existing, case.update), case.want);
        }
    }
}