use crate::config::{RealtimeConfig, SessionUpdateConfig, ToolDefinition};
use crate::error::{RealtimeError, Result};
use crate::events::{ServerEvent, ToolCall, ToolResponse};
use crate::model::BoxedModel;
use crate::session::ContextMutationOutcome;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Coarse lifecycle state tracked by the runner while it processes server events.
#[derive(Debug, Clone, Default, PartialEq)]
pub enum RunnerState {
/// No response is in flight. This is the default state.
#[default]
Idle,
/// Entered from `Idle` when a `ResponseCreated` event is observed.
Generating,
/// Entered from `Idle` or `Generating` when a `FunctionCallDone` event is observed.
ExecutingTool,
/// A session resumption (reconnect with a new configuration) is queued and
/// will be attempted once the current turn completes.
PendingResumption {
/// Configuration to reconnect with.
config: Box<crate::config::RealtimeConfig>,
/// Optional text sent as a user message right after the new session is established.
bridge_message: Option<String>,
/// Number of resumption attempts that have already failed for this entry.
attempts: u8,
},
}
/// Executes tool (function) calls on behalf of the runner.
///
/// Implementations must be `Send + Sync` because handlers are stored behind
/// `Arc<dyn ToolHandler>` and invoked from async code.
#[async_trait]
pub trait ToolHandler: Send + Sync {
/// Runs the tool for `call` and returns its JSON output.
///
/// When the runner invokes the handler itself, an `Err` is not propagated:
/// it is turned into a `{"error": "<message>"}` JSON payload.
async fn execute(&self, call: &ToolCall) -> Result<serde_json::Value>;
}
/// Adapts a plain synchronous closure into a [`ToolHandler`].
pub struct FnToolHandler<F: Fn(&ToolCall) -> Result<serde_json::Value> + Send + Sync> {
    handler: F,
}

impl<F: Fn(&ToolCall) -> Result<serde_json::Value> + Send + Sync> FnToolHandler<F> {
    /// Wraps `handler` so it can be registered as a tool.
    pub fn new(handler: F) -> Self {
        FnToolHandler { handler }
    }
}

#[async_trait]
impl<F: Fn(&ToolCall) -> Result<serde_json::Value> + Send + Sync> ToolHandler
    for FnToolHandler<F>
{
    /// Invokes the wrapped closure synchronously and returns its result unchanged.
    async fn execute(&self, call: &ToolCall) -> Result<serde_json::Value> {
        let callback = &self.handler;
        callback(call)
    }
}
#[allow(dead_code)]
pub struct AsyncToolHandler<F, Fut>
where
F: Fn(ToolCall) -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<serde_json::Value>> + Send,
{
handler: F,
}
impl<F, Fut> AsyncToolHandler<F, Fut>
where
F: Fn(ToolCall) -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<serde_json::Value>> + Send,
{
pub fn new(handler: F) -> Self {
Self { handler }
}
}
/// Callbacks invoked by the runner as server events arrive.
///
/// Every method has a default implementation that does nothing and returns
/// `Ok(())`, so implementors only override the callbacks they care about.
/// An `Err` returned from a callback is propagated by the runner's event
/// handling (via `?`) unless noted otherwise.
#[async_trait]
pub trait EventHandler: Send + Sync {
/// Called for each audio delta of the item identified by `_item_id`.
async fn on_audio(&self, _audio: &[u8], _item_id: &str) -> Result<()> {
Ok(())
}
/// Called for each text delta of the item identified by `_item_id`.
async fn on_text(&self, _text: &str, _item_id: &str) -> Result<()> {
Ok(())
}
/// Called for each transcript delta of the item identified by `_item_id`.
async fn on_transcript(&self, _transcript: &str, _item_id: &str) -> Result<()> {
Ok(())
}
/// Called when the server reports that speech started at `_audio_start_ms`.
async fn on_speech_started(&self, _audio_start_ms: u64) -> Result<()> {
Ok(())
}
/// Called when the server reports that speech stopped at `_audio_end_ms`.
async fn on_speech_stopped(&self, _audio_end_ms: u64) -> Result<()> {
Ok(())
}
/// Called when a response has completed (`ResponseDone`).
async fn on_response_done(&self) -> Result<()> {
Ok(())
}
/// Called for server error events, for errors yielded by the event stream,
/// and when a queued session resumption fails (in that last case the
/// callback's own result is ignored).
async fn on_error(&self, _error: &RealtimeError) -> Result<()> {
Ok(())
}
}
/// [`EventHandler`] that ignores every event; it relies entirely on the
/// trait's default method implementations. Used when no handler is supplied
/// to the builder.
#[derive(Debug, Clone, Default)]
pub struct NoOpEventHandler;
#[async_trait]
impl EventHandler for NoOpEventHandler {}
/// Behavioural switches for the runner's event loop.
#[derive(Debug, Clone)]
pub struct RunnerConfig {
    /// When `true`, a registered tool handler is invoked automatically for
    /// every `FunctionCallDone` event.
    pub auto_execute_tools: bool,
    /// When `true`, the output of an automatically executed tool is sent back
    /// to the session as a tool response.
    pub auto_respond_tools: bool,
    /// Intended upper bound on concurrently running tools.
    /// NOTE(review): not read anywhere in this file — confirm it is enforced elsewhere.
    pub max_concurrent_tools: usize,
}

impl Default for RunnerConfig {
    /// Tools are executed and answered automatically, with a limit of 4.
    fn default() -> Self {
        Self { auto_execute_tools: true, auto_respond_tools: true, max_concurrent_tools: 4 }
    }
}
/// Step-by-step builder for [`RealtimeRunner`].
pub struct RealtimeRunnerBuilder {
// Model used to open sessions; `build` fails if this is still `None`.
model: Option<BoxedModel>,
// Session configuration passed to the model on connect.
config: RealtimeConfig,
// Runner-level behaviour switches.
runner_config: RunnerConfig,
// Registered tools keyed by tool name: the definition plus its handler.
tools: HashMap<String, (ToolDefinition, Arc<dyn ToolHandler>)>,
// Optional event callbacks; `build` falls back to `NoOpEventHandler`.
event_handler: Option<Arc<dyn EventHandler>>,
}
// `Default` simply delegates to `RealtimeRunnerBuilder::new`.
impl Default for RealtimeRunnerBuilder {
fn default() -> Self {
Self::new()
}
}
impl RealtimeRunnerBuilder {
pub fn new() -> Self {
Self {
model: None,
config: RealtimeConfig::default(),
runner_config: RunnerConfig::default(),
tools: HashMap::new(),
event_handler: None,
}
}
pub fn model(mut self, model: BoxedModel) -> Self {
self.model = Some(model);
self
}
pub fn config(mut self, config: RealtimeConfig) -> Self {
self.config = config;
self
}
pub fn runner_config(mut self, config: RunnerConfig) -> Self {
self.runner_config = config;
self
}
pub fn instruction(mut self, instruction: impl Into<String>) -> Self {
self.config.instruction = Some(instruction.into());
self
}
pub fn voice(mut self, voice: impl Into<String>) -> Self {
self.config.voice = Some(voice.into());
self
}
pub fn tool(mut self, definition: ToolDefinition, handler: impl ToolHandler + 'static) -> Self {
let name = definition.name.clone();
self.tools.insert(name, (definition, Arc::new(handler)));
self
}
pub fn tool_fn<F>(self, definition: ToolDefinition, handler: F) -> Self
where
F: Fn(&ToolCall) -> Result<serde_json::Value> + Send + Sync + 'static,
{
self.tool(definition, FnToolHandler::new(handler))
}
pub fn event_handler(mut self, handler: impl EventHandler + 'static) -> Self {
self.event_handler = Some(Arc::new(handler));
self
}
pub fn build(self) -> Result<RealtimeRunner> {
let model = self.model.ok_or_else(|| RealtimeError::config("Model is required"))?;
let mut config = self.config;
if !self.tools.is_empty() {
let tool_defs: Vec<ToolDefinition> =
self.tools.values().map(|(def, _)| def.clone()).collect();
config.tools = Some(tool_defs);
}
Ok(RealtimeRunner {
model,
config: Arc::new(RwLock::new(config)),
runner_config: self.runner_config,
tools: self.tools,
event_handler: self.event_handler.unwrap_or_else(|| Arc::new(NoOpEventHandler)),
session: Arc::new(RwLock::new(None)),
state: Arc::new(RwLock::new(RunnerState::Idle)),
})
}
}
/// Drives a realtime session: connects through the model, forwards client
/// input, dispatches server events to the event handler, executes tools and
/// handles configuration changes that require reconnecting.
pub struct RealtimeRunner {
// Model used to open (and re-open) sessions.
model: BoxedModel,
// Current session configuration; updated by session updates and resume tokens.
config: Arc<RwLock<RealtimeConfig>>,
// Runner-level behaviour switches.
runner_config: RunnerConfig,
// Registered tools keyed by tool name.
tools: HashMap<String, (ToolDefinition, Arc<dyn ToolHandler>)>,
// Callbacks invoked as server events arrive.
event_handler: Arc<dyn EventHandler>,
// Active session, or `None` when not connected (also briefly during resumption).
session: Arc<RwLock<Option<Arc<dyn crate::session::RealtimeSession>>>>,
// Coarse lifecycle state, including any queued resumption.
state: Arc<RwLock<RunnerState>>,
}
impl RealtimeRunner {
/// Returns a shared handle to the active session, or a "Not connected"
/// connection error when there is none.
async fn session_handle(&self) -> Result<Arc<dyn crate::session::RealtimeSession>> {
    let slot = self.session.read().await;
    match slot.as_ref() {
        Some(active) => Ok(Arc::clone(active)),
        None => Err(RealtimeError::connection("Not connected")),
    }
}
/// Returns a fresh [`RealtimeRunnerBuilder`].
pub fn builder() -> RealtimeRunnerBuilder {
RealtimeRunnerBuilder::new()
}
/// Opens a session using a snapshot of the current configuration and stores
/// it as the active session, replacing whatever was stored before.
pub async fn connect(&self) -> Result<()> {
    // The read guard is released at the end of this statement, before connecting.
    let snapshot = self.config.read().await.clone();
    let established = self.model.connect(snapshot).await?;
    let mut slot = self.session.write().await;
    *slot = Some(established.into());
    Ok(())
}
/// Reports whether a session is stored and that session says it is connected.
pub async fn is_connected(&self) -> bool {
    let slot = self.session.read().await;
    match slot.as_ref() {
        Some(active) => active.is_connected(),
        None => false,
    }
}
/// Returns the id of the active session, or `None` when no session is stored.
pub async fn session_id(&self) -> Option<String> {
    let slot = self.session.read().await;
    let active = slot.as_ref()?;
    Some(active.session_id().to_string())
}
pub async fn send_client_event(&self, event: crate::events::ClientEvent) -> Result<()> {
match event {
crate::events::ClientEvent::UpdateSession { instructions, tools } => {
let update_config = SessionUpdateConfig(crate::config::RealtimeConfig {
instruction: instructions,
tools,
..Default::default()
});
self.update_session(update_config).await
}
other => {
let session = self.session_handle().await?;
session.send_event(other).await
}
}
}
/// Copies every field that is set (`Some`) in `update` onto `base`; fields
/// that are `None` in `update` leave `base` untouched. Only `instruction`,
/// `tools`, `voice`, `temperature` and `extra` are considered.
fn merge_config(base: &mut RealtimeConfig, update: &SessionUpdateConfig) {
    let incoming = &update.0;
    if incoming.instruction.is_some() {
        base.instruction = incoming.instruction.clone();
    }
    if incoming.tools.is_some() {
        base.tools = incoming.tools.clone();
    }
    if incoming.voice.is_some() {
        base.voice = incoming.voice.clone();
    }
    if incoming.temperature.is_some() {
        base.temperature = incoming.temperature;
    }
    if incoming.extra.is_some() {
        base.extra = incoming.extra.clone();
    }
}
/// Applies a session update without a bridge message; see
/// `update_session_with_bridge` for the full behaviour.
pub async fn update_session(&self, config: SessionUpdateConfig) -> Result<()> {
self.update_session_with_bridge(config, None).await
}
/// Merges `config` into the stored configuration and applies the result to
/// the live session via `mutate_context`.
///
/// Depending on the session's answer:
/// - `Applied`: the change took effect in place. `bridge_message`, if any, is
///   then sent as a user message on the same session.
/// - `RequiresResumption`: the session has to be re-established with the
///   returned configuration. If the runner is `Idle` this is done right away;
///   otherwise the resumption is stored in `RunnerState::PendingResumption`
///   (overwriting any previously queued one) for later execution.
///
/// # Errors
/// Fails when no session is connected, when `mutate_context` or sending the
/// bridge message fails, or when an immediate resumption fails. In the last
/// case the resumption is also queued with `attempts: 1` before returning.
///
/// The stored configuration is merged before the session lookup, so the merge
/// persists even when this returns a "Not connected" error.
pub async fn update_session_with_bridge(
&self,
config: SessionUpdateConfig,
bridge_message: Option<String>,
) -> Result<()> {
let mut full_config = self.config.write().await;
Self::merge_config(&mut full_config, &config);
let cloned_config = full_config.clone();
// Release the config lock before doing any session I/O.
drop(full_config);
let session = self.session_handle().await?;
match session.mutate_context(cloned_config).await? {
ContextMutationOutcome::Applied => {
tracing::info!("Context mutated natively mid-flight.");
if let Some(msg) = bridge_message {
let event = crate::events::ClientEvent::Message {
role: "user".to_string(),
parts: vec![adk_core::types::Part::Text { text: msg }],
};
session.send_event(event).await?;
}
Ok(())
}
ContextMutationOutcome::RequiresResumption(new_config) => {
// Let go of our handle to the old session before it gets replaced.
drop(session);
let mut state_guard = self.state.write().await;
if *state_guard == RunnerState::Idle {
// The state lock is released first: the fallback below re-acquires it.
drop(state_guard); tracing::info!("Runner is idle. Executing resumption immediately.");
if let Err(e) =
self.execute_resumption((*new_config).clone(), bridge_message.clone()).await
{
tracing::error!("Immediate resumption failed: {}. Queueing for retry.", e);
let mut fallback_state = self.state.write().await;
// This failed attempt counts as the first one.
*fallback_state = RunnerState::PendingResumption {
config: Box::new(*new_config),
bridge_message,
attempts: 1,
};
return Err(e);
}
} else {
if let RunnerState::PendingResumption { .. } = *state_guard {
tracing::warn!(
"Runner already had a pending resumption. Overwriting with last-write-wins policy."
);
} else {
tracing::info!("Runner is busy ({:?}). Queueing resumption.", *state_guard);
}
// Queued while the state lock is still held, so the busy check and the write are atomic.
*state_guard = RunnerState::PendingResumption {
config: new_config,
bridge_message,
attempts: 0,
};
}
Ok(())
}
}
}
/// Replaces the current transport with a new session opened from `new_config`.
///
/// The old session (if any) is removed and closed first; a failure to close
/// it is only logged. After the new session is stored, `bridge_message` (if
/// any) is sent on it. If connecting fails, no session remains stored.
async fn execute_resumption(
    &self,
    new_config: crate::config::RealtimeConfig,
    bridge_message: Option<String>,
) -> Result<()> {
    tracing::warn!("Executing transport resumption with new configuration.");

    // The write guard is released at the end of this statement, before closing.
    let detached = self.session.write().await.take();
    if let Some(stale) = detached {
        if let Err(e) = stale.close().await {
            tracing::warn!("Failed to cleanly close old session during resumption: {}", e);
        }
    }

    let replacement = self.model.connect(new_config).await?;
    *self.session.write().await = Some(replacement.into());

    if let Some(msg) = bridge_message {
        self.inject_bridge_message(msg).await?;
    }
    tracing::info!("Resumption complete. New transport established.");
    Ok(())
}
/// Sends `msg` to the active session as a single-part text message with role `user`.
async fn inject_bridge_message(&self, msg: String) -> Result<()> {
    tracing::info!("Injecting bridge message post-resumption.");
    let text_part = adk_core::types::Part::Text { text: msg };
    let bridge = crate::events::ClientEvent::Message {
        role: "user".to_string(),
        parts: vec![text_part],
    };
    self.session_handle().await?.send_event(bridge).await
}
/// Forwards base64-encoded audio to the active session.
/// Fails with a connection error when no session is connected.
pub async fn send_audio(&self, audio_base64: &str) -> Result<()> {
    self.session_handle().await?.send_audio_base64(audio_base64).await
}
/// Forwards a text message to the active session.
/// Fails with a connection error when no session is connected.
pub async fn send_text(&self, text: &str) -> Result<()> {
    self.session_handle().await?.send_text(text).await
}
/// Asks the active session to commit its audio input.
/// Fails with a connection error when no session is connected.
pub async fn commit_audio(&self) -> Result<()> {
    self.session_handle().await?.commit_audio().await
}
/// Asks the active session to create a response.
/// Fails with a connection error when no session is connected.
pub async fn create_response(&self) -> Result<()> {
    self.session_handle().await?.create_response().await
}
/// Asks the active session to interrupt.
/// Fails with a connection error when no session is connected.
pub async fn interrupt(&self) -> Result<()> {
    self.session_handle().await?.interrupt().await
}
/// Waits for the next server event from the active session.
///
/// When no session is stored, this sleeps for 10 ms and returns `None`
/// rather than an error, so a polling caller does not spin.
pub async fn next_event(&self) -> Option<Result<ServerEvent>> {
    match self.session_handle().await {
        Ok(session) => {
            // Give other tasks a chance to run before blocking on the stream.
            tokio::task::yield_now().await;
            session.next_event().await
        }
        Err(_) => {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            None
        }
    }
}
/// Forwards a tool response to the active session.
/// Fails with a connection error when no session is connected.
pub async fn send_tool_response(&self, response: ToolResponse) -> Result<()> {
    self.session_handle().await?.send_tool_response(response).await
}
/// Runs the event loop until the event stream ends or an error occurs.
///
/// Each iteration re-reads the active session, so a session swapped in by a
/// resumption is picked up. When a stream ends, the loop continues only if a
/// session with a different id is now stored; if the id is unchanged or no
/// session is stored at that moment, the loop finishes with `Ok(())`.
///
/// # Errors
/// Returns an error when no session is connected, when handling an event
/// fails, or when the stream yields an error (after notifying `on_error`).
pub async fn run(&self) -> Result<()> {
    loop {
        let session = self.session_handle().await?;
        let id_before = session.session_id().to_string();
        match session.next_event().await {
            Some(Ok(event)) => self.handle_event(event).await?,
            Some(Err(e)) => {
                self.event_handler.on_error(&e).await?;
                return Err(e);
            }
            None => match self.session_id().await {
                // The stream ended because the session was replaced: keep going on the new one.
                Some(id_now) if id_now != id_before => continue,
                _ => return Ok(()),
            },
        }
    }
}
/// Processes one server event: updates the runner state, then dispatches the
/// event to the matching callback, tool execution or configuration update.
///
/// # Errors
/// Propagates errors from event-handler callbacks, tool-response delivery and
/// the resumption queue check.
async fn handle_event(&self, event: ServerEvent) -> Result<()> {
    // State tracking runs first and borrows the event, so the dispatch below
    // can still take it by value.
    match &event {
        ServerEvent::ResponseCreated { .. } => {
            let mut state = self.state.write().await;
            if let RunnerState::Idle = *state {
                *state = RunnerState::Generating;
            }
        }
        ServerEvent::FunctionCallDone { .. } => {
            let mut state = self.state.write().await;
            // A queued resumption must not be overwritten here.
            if let RunnerState::Generating | RunnerState::Idle = *state {
                *state = RunnerState::ExecutingTool;
            }
        }
        _ => {}
    }
    match event {
        ServerEvent::AudioDelta { delta, item_id, .. } => {
            self.event_handler.on_audio(&delta, &item_id).await?;
        }
        ServerEvent::TextDelta { delta, item_id, .. } => {
            self.event_handler.on_text(&delta, &item_id).await?;
        }
        ServerEvent::TranscriptDelta { delta, item_id, .. } => {
            self.event_handler.on_transcript(&delta, &item_id).await?;
        }
        ServerEvent::SpeechStarted { audio_start_ms, .. } => {
            self.event_handler.on_speech_started(audio_start_ms).await?;
        }
        ServerEvent::SpeechStopped { audio_end_ms, .. } => {
            self.event_handler.on_speech_stopped(audio_end_ms).await?;
        }
        ServerEvent::ResponseDone { .. } => {
            self.event_handler.on_response_done().await?;
            // The turn is over: run any queued resumption, or reset the state to `Idle`.
            self.check_resumption_queue().await?;
        }
        ServerEvent::FunctionCallDone { call_id, name, arguments, .. } => {
            if self.runner_config.auto_execute_tools {
                self.execute_tool_call(&call_id, &name, &arguments).await?;
            }
        }
        ServerEvent::SessionUpdated { session, .. } => {
            if let Some(token) = session.get("resumeToken").and_then(|t| t.as_str()) {
                tracing::info!(
                    "Received Gemini sessionResumption token, saving for future reconnects."
                );
                let mut config = self.config.write().await;
                let extra = config.extra.get_or_insert_with(|| serde_json::json!({}));
                // A JSON null is treated as an empty object, as `Value` indexing would.
                if extra.is_null() {
                    *extra = serde_json::json!({});
                }
                // Indexing a non-object `Value` by key panics, and `extra` is
                // caller-supplied, so insert through a checked accessor.
                match extra.as_object_mut() {
                    Some(map) => {
                        map.insert(
                            "resumeToken".to_string(),
                            serde_json::Value::String(token.to_string()),
                        );
                    }
                    None => {
                        tracing::warn!(
                            "Cannot store resumeToken: `extra` configuration is not a JSON object."
                        );
                    }
                }
            }
        }
        ServerEvent::Error { error, .. } => {
            let err = RealtimeError::server(error.code.unwrap_or_default(), error.message);
            self.event_handler.on_error(&err).await?;
        }
        _ => {}
    }
    Ok(())
}
/// Resets the runner state to `Idle` and, if a resumption was queued, runs it.
///
/// A failed resumption is re-queued with an incremented attempt count until
/// three attempts have failed, after which it is dropped. The failure is
/// reported through `on_error` (its result is ignored) and this function
/// still returns `Ok(())`.
async fn check_resumption_queue(&self) -> Result<()> {
    let mut state = self.state.write().await;
    // Taking the state leaves the default (`Idle`) behind in every case.
    let (config, bridge_message, attempts) = match std::mem::take(&mut *state) {
        RunnerState::PendingResumption { config, bridge_message, attempts } => {
            (config, bridge_message, attempts)
        }
        _ => return Ok(()),
    };
    let next_attempt = attempts + 1;
    tracing::info!(
        "Executing queued resumption after turn completion. (Attempt {})",
        next_attempt
    );
    // The state lock must not be held while reconnecting.
    drop(state);

    let outcome = self.execute_resumption((*config).clone(), bridge_message.clone()).await;
    if let Err(e) = outcome {
        tracing::error!("Resumption failed: {}.", e);
        let mut fallback_state = self.state.write().await;
        *fallback_state = if next_attempt >= 3 {
            tracing::error!(
                "Maximum resumption attempts reached (3). Dropping queued mutation to prevent infinite loop."
            );
            RunnerState::Idle
        } else {
            tracing::info!("Restoring pending queue state for retry.");
            RunnerState::PendingResumption { config, bridge_message, attempts: next_attempt }
        };
        let _ = self.event_handler.on_error(&e).await;
    }
    Ok(())
}
async fn execute_tool_call(&self, call_id: &str, name: &str, arguments: &str) -> Result<()> {
let handler = self.tools.get(name).map(|(_, h)| h.clone());
let result = if let Some(handler) = handler {
let args: serde_json::Value = serde_json::from_str(arguments)
.unwrap_or(serde_json::Value::Object(Default::default()));
let call =
ToolCall { call_id: call_id.to_string(), name: name.to_string(), arguments: args };
match handler.execute(&call).await {
Ok(value) => value,
Err(e) => serde_json::json!({
"error": e.to_string()
}),
}
} else {
serde_json::json!({
"error": format!("Unknown tool: {}", name)
})
};
if self.runner_config.auto_respond_tools {
let response = ToolResponse { call_id: call_id.to_string(), output: result };
if let Ok(session) = self.session_handle().await {
session.send_tool_response(response).await?;
}
}
Ok(())
}
/// Closes the active session if there is one; does nothing when not connected.
/// The closed session remains stored in the runner.
pub async fn close(&self) -> Result<()> {
    match self.session_handle().await {
        Ok(session) => {
            session.close().await?;
            Ok(())
        }
        Err(_) => Ok(()),
    }
}
}