gemini_live_api/client/handle.rs

1use crate::error::GeminiError;
2use crate::types::*;
3use base64::Engine as _;
4use std::sync::Arc;
5use tokio::sync::{mpsc, oneshot};
6use tracing::{error, info, trace, warn};
7
8use super::GeminiLiveClientBuilder;
9
/// Client handle for a Gemini Live session whose I/O is run by a background listener task.
///
/// Outgoing messages are forwarded to that task over an `mpsc` channel, and shutdown is
/// requested through a one-shot channel. A shared `Arc<S>` user state travels with the handle.
/// Call [`GeminiLiveClient::close`] to shut down explicitly. Dropping the handle without
/// closing it attempts the same shutdown (see the `Drop` impl).
pub struct GeminiLiveClient<S: Clone + Send + Sync + 'static> {
    /// One-shot shutdown signal for the listener task. It is taken (left as `None`) when
    /// `close` or `Drop` sends the signal.
    pub(crate) shutdown_tx: Option<oneshot::Sender<()>>,
    /// Sender half of the channel feeding the listener task. It is taken (left as `None`) by
    /// `close`, after which sends on this handle fail with `GeminiError::NotReady`.
    pub(crate) outgoing_sender: Option<mpsc::Sender<ClientMessagePayload>>,
    /// Shared user state, returned by `state()`. NOTE(review): presumably the value given to
    /// `builder_with_state`; confirm in the builder.
    pub(crate) state: Arc<S>,
}
15
16impl<S: Clone + Send + Sync + 'static> GeminiLiveClient<S> {
17    pub async fn close(&mut self) -> Result<(), GeminiError> {
18        info!("Client close requested.");
19        if let Some(tx) = self.shutdown_tx.take() {
20            if tx.send(()).is_err() {
21                info!("Shutdown signal failed: Listener task already gone.");
22            } else {
23                info!("Shutdown signal sent to listener task.");
24            }
25        }
26        self.outgoing_sender.take();
27        Ok(())
28    }
29
30    pub fn get_outgoing_mpsc_sender_clone(
31        &self,
32    ) -> Option<tokio::sync::mpsc::Sender<ClientMessagePayload>> {
33        self.outgoing_sender.clone()
34    }
35
36    pub fn builder_with_state(
37        api_key: String,
38        model: String,
39        state: S,
40    ) -> GeminiLiveClientBuilder<S> {
41        GeminiLiveClientBuilder::new_with_state(api_key, model, state)
42    }
43
44    async fn send_message(&self, payload: ClientMessagePayload) -> Result<(), GeminiError> {
45        if let Some(sender) = &self.outgoing_sender {
46            let sender = sender.clone();
47            match sender.send(payload).await {
48                Ok(_) => {
49                    trace!("Message sent to listener task via channel.");
50                    Ok(())
51                }
52                Err(_) => {
53                    error!("Failed to send message to listener task: Channel closed.");
54                    Err(GeminiError::SendError)
55                }
56            }
57        } else {
58            error!("Cannot send message: Client is closed or sender missing.");
59            Err(GeminiError::NotReady)
60        }
61    }
62
63    pub async fn send_text_turn(&self, text: String, end_of_turn: bool) -> Result<(), GeminiError> {
64        let content_part = Part {
65            text: Some(text),
66            ..Default::default()
67        };
68        let content = Content {
69            parts: vec![content_part],
70            role: Some(Role::User),
71        };
72        let client_content_msg = BidiGenerateContentClientContent {
73            turns: Some(vec![content]),
74            turn_complete: Some(end_of_turn),
75        };
76        self.send_message(ClientMessagePayload::ClientContent(client_content_msg))
77            .await
78    }
79
80    pub async fn send_audio_chunk(
81        &self,
82        audio_samples: &[i16],
83        sample_rate: u32,
84        channels: u16,
85    ) -> Result<(), GeminiError> {
86        if audio_samples.is_empty() {
87            return Ok(());
88        }
89        let mut byte_data = Vec::with_capacity(audio_samples.len() * 2);
90        for sample in audio_samples {
91            byte_data.extend_from_slice(&sample.to_le_bytes());
92        }
93
94        let encoded_data = base64::engine::general_purpose::STANDARD.encode(&byte_data);
95        let mime_type = format!("audio/pcm;rate={}", sample_rate);
96
97        let audio_blob = Blob {
98            mime_type,
99            data: encoded_data,
100        };
101
102        let realtime_input = BidiGenerateContentRealtimeInput {
103            audio: Some(audio_blob),
104            ..Default::default()
105        };
106
107        self.send_message(ClientMessagePayload::RealtimeInput(realtime_input))
108            .await
109    }
110
111    pub async fn send_realtime_text(&self, text: String) -> Result<(), GeminiError> {
112        let realtime_input = BidiGenerateContentRealtimeInput {
113            text: Some(text),
114            ..Default::default()
115        };
116        self.send_message(ClientMessagePayload::RealtimeInput(realtime_input))
117            .await
118    }
119
120    pub async fn send_activity_start(&self) -> Result<(), GeminiError> {
121        let realtime_input = BidiGenerateContentRealtimeInput {
122            activity_start: Some(ActivityStart {}),
123            ..Default::default()
124        };
125        self.send_message(ClientMessagePayload::RealtimeInput(realtime_input))
126            .await
127    }
128
129    pub async fn send_activity_end(&self) -> Result<(), GeminiError> {
130        let realtime_input = BidiGenerateContentRealtimeInput {
131            activity_end: Some(ActivityEnd {}),
132            ..Default::default()
133        };
134        self.send_message(ClientMessagePayload::RealtimeInput(realtime_input))
135            .await
136    }
137
138    pub async fn send_audio_stream_end(&self) -> Result<(), GeminiError> {
139        let realtime_input = BidiGenerateContentRealtimeInput {
140            audio_stream_end: Some(true),
141            ..Default::default()
142        };
143        self.send_message(ClientMessagePayload::RealtimeInput(realtime_input))
144            .await
145    }
146
147    pub fn state(&self) -> Arc<S> {
148        self.state.clone()
149    }
150}
151
152impl<S: Clone + Send + Sync + 'static> Drop for GeminiLiveClient<S> {
153    fn drop(&mut self) {
154        if self.shutdown_tx.is_some() {
155            warn!("GeminiLiveClient dropped without explicit close(). Attempting shutdown.");
156            if let Some(tx) = self.shutdown_tx.take() {
157                let _ = tx.send(());
158            }
159            self.outgoing_sender.take();
160        }
161    }
162}