use rust_genai::types::content::{Content, FunctionResponse, Part, Role};
use rust_genai::types::enums::Modality;
use rust_genai::types::live_types::{
AudioTranscriptionConfig, LiveConnectConfig, LiveSendClientContentParameters,
LiveSendToolResponseParameters,
};
use rust_genai::{Client, Result};
use serde_json::json;
use std::fs::File;
use std::io::{Seek, SeekFrom, Write};
/// Program entry point.
///
/// Runs on the Tokio runtime set up by `#[tokio::main]` and simply forwards
/// the result of [`run`], so any error it returns becomes the process result.
#[tokio::main]
async fn main() -> Result<()> {
run().await
}
/// Connects to the Live API, sends a single weather question and streams the
/// reply (including any tool calls) until the turn ends.
///
/// Once the session has been opened it is closed on every path, including
/// when sending the prompt or handling the response fails.
///
/// # Errors
///
/// Returns the first error encountered: client construction, config
/// building, connecting, the prompt/response exchange, or — if everything
/// else succeeded — closing the session.
async fn run() -> Result<()> {
    let client = Client::from_env()?;
    let model = "gemini-3.1-flash-live-preview";
    let config = build_live_config()?;
    println!("连接 Live API 中... (model={model})");
    let mut session = client.live().connect(model, config).await?;
    println!("连接成功。发送请求...");

    // Run the exchange inside its own future so that a failure does not
    // return early and skip the close below.
    let exchange = async {
        send_weather_prompt(&session).await?;
        println!("请求已发送,等待响应...");
        handle_session(&mut session).await
    }
    .await;

    let closed = session.close().await;

    // Prefer the exchange error: it is the root cause, a close failure after
    // it is most likely just a consequence.
    exchange?;
    closed?;
    Ok(())
}
/// Builds the connection config for the Live session.
///
/// The config requests audio responses, enables output audio transcription
/// and declares one tool, `get_weather`, that takes a required `city` string.
///
/// # Errors
///
/// Returns an error if the JSON tool declaration cannot be deserialized into
/// the config's `tools` field.
fn build_live_config() -> Result<LiveConnectConfig> {
    // The tool list is written as one JSON literal and then deserialized
    // into whatever type the `tools` field expects.
    let tool_declarations = json!([{
        "functionDeclarations": [{
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "OBJECT",
                "properties": {
                    "city": {
                        "type": "STRING",
                        "description": "The city name"
                    }
                },
                "required": ["city"]
            }
        }]
    }]);
    let tools = serde_json::from_value(tool_declarations)?;

    Ok(LiveConnectConfig {
        response_modalities: Some(vec![Modality::Audio]),
        output_audio_transcription: Some(AudioTranscriptionConfig::default()),
        tools: Some(tools),
        ..Default::default()
    })
}
/// Sends the fixed weather question as a single, completed user turn.
///
/// # Errors
///
/// Propagates any error returned by `send_client_content`.
async fn send_weather_prompt(session: &rust_genai::live::LiveSession) -> Result<()> {
    let user_turn = Content {
        role: Some(Role::User),
        parts: vec![Part::text("What's the weather like in Beijing?")],
    };
    let request = LiveSendClientContentParameters {
        turns: Some(vec![user_turn]),
        // Mark the turn as complete so the request is not left open for
        // further client content.
        turn_complete: Some(true),
    };
    session.send_client_content(request).await
}
async fn handle_session(session: &mut rust_genai::live::LiveSession) -> Result<()> {
let mut state = SessionState::new();
let audio_out_path = std::env::var("GENAI_AUDIO_OUT_PATH").ok();
let deadline = std::time::Duration::from_secs(30);
loop {
let message = match tokio::time::timeout(deadline, session.receive()).await {
Ok(Some(msg)) => {
let msg = msg?;
if msg.tool_call.is_some() {
println!();
}
msg
}
Ok(None) => {
println!("\n\n会话结束。");
break;
}
Err(_) => {
println!("\n等待响应超时。");
break;
}
};
if let Some(tool_call) = message.tool_call.as_ref() {
handle_tool_calls(session, tool_call).await?;
}
let server_content = message.server_content.as_ref();
let has_transcription = server_content
.and_then(|c| c.output_transcription.as_ref())
.is_some();
handle_transcription(server_content, &mut state);
handle_model_turn(
server_content,
&mut state,
audio_out_path.as_deref(),
has_transcription,
)?;
if server_content
.and_then(|c| c.turn_complete)
.unwrap_or(false)
{
finalize_turn(&mut state, audio_out_path.as_deref())?;
break;
}
}
Ok(())
}
/// Answers every function call in `tool_call` with a canned weather result.
///
/// Each call is logged as `[function_call] name(args)` and answered with the
/// same fixed payload, echoing the call's `id` and `name`. Nothing is sent
/// when the message carries no function calls.
///
/// # Errors
///
/// Propagates any error returned by `send_tool_response`.
async fn handle_tool_calls(
    session: &rust_genai::live::LiveSession,
    tool_call: &rust_genai::types::live_types::LiveServerToolCall,
) -> Result<()> {
    let calls = match tool_call.function_calls.as_ref() {
        Some(calls) if !calls.is_empty() => calls,
        _ => return Ok(()),
    };

    let mut responses = Vec::with_capacity(calls.len());
    for call in calls.iter() {
        let name = call.name.as_deref().unwrap_or("unknown");
        // Arguments are only logged; a serialization failure prints an empty
        // string rather than aborting.
        let args = serde_json::to_string(&call.args).unwrap_or_default();
        println!("[function_call] {name}({args})");

        responses.push(FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: call.id.clone(),
            name: call.name.clone(),
            response: Some(json!({"temperature": "20C", "condition": "sunny"})),
        });
    }

    let reply = LiveSendToolResponseParameters {
        function_responses: Some(responses),
    };
    session.send_tool_response(reply).await
}
/// Prints the output transcription text of a server message, if it has any.
///
/// Does nothing when there is no server content, no transcription, or the
/// transcription carries no text.
fn handle_transcription(
    server_content: Option<&rust_genai::types::live_types::LiveServerContent>,
    state: &mut SessionState,
) {
    let Some(content) = server_content else {
        return;
    };
    let Some(transcription) = content.output_transcription.as_ref() else {
        return;
    };
    if let Some(text) = transcription.text.as_deref() {
        emit_text(text, state);
    }
}
/// Processes the parts of the model turn carried by a server message.
///
/// Parts flagged as thoughts are skipped. Text parts are printed only when
/// the message has no output transcription (`has_transcription == false`);
/// inline data parts are forwarded to `handle_inline_data`. All other part
/// kinds are ignored.
///
/// # Errors
///
/// Propagates any error from `handle_inline_data`.
fn handle_model_turn(
    server_content: Option<&rust_genai::types::live_types::LiveServerContent>,
    state: &mut SessionState,
    audio_out_path: Option<&str>,
    has_transcription: bool,
) -> Result<()> {
    use rust_genai::types::content::PartKind;

    let Some(turn) = server_content.and_then(|content| content.model_turn.as_ref()) else {
        return Ok(());
    };

    let visible_parts = turn
        .parts
        .iter()
        .filter(|part| !part.thought.unwrap_or(false));

    for part in visible_parts {
        match &part.kind {
            PartKind::Text { text } => {
                if !has_transcription {
                    emit_text(text, state);
                }
            }
            PartKind::InlineData { inline_data } => {
                handle_inline_data(inline_data, state, audio_out_path)?;
            }
            _ => {}
        }
    }
    Ok(())
}
/// Appends one inline audio blob to the WAV output file.
///
/// Blobs whose MIME type does not start with `audio/` are ignored, as is
/// everything when no output path is configured. The WAV writer is created
/// on the first audio blob, using the sample rate from that blob's MIME type
/// (falling back to 24 000 Hz).
///
/// # Errors
///
/// Propagates errors from creating the WAV file or writing the chunk.
fn handle_inline_data(
    inline_data: &rust_genai::types::content::Blob,
    state: &mut SessionState,
    audio_out_path: Option<&str>,
) -> Result<()> {
    let path = match audio_out_path {
        Some(path) if inline_data.mime_type.starts_with("audio/") => path,
        _ => return Ok(()),
    };

    // Move the writer out, creating it on first use, then put it straight
    // back; if creation fails the slot simply stays empty.
    let writer = match state.wav_writer.take() {
        Some(existing) => existing,
        None => {
            let rate = parse_sample_rate(&inline_data.mime_type).unwrap_or(24_000);
            WavWriter::create(path, rate)?
        }
    };
    state
        .wav_writer
        .insert(writer)
        .write_chunk(&inline_data.data)
}
/// Wraps up the output of a turn.
///
/// Ends the current text line if any text was printed and, when audio was
/// recorded, rewrites the WAV header with the final data length and reports
/// where the file was saved.
///
/// # Errors
///
/// Propagates any error from rewriting the WAV header.
fn finalize_turn(state: &mut SessionState, audio_out_path: Option<&str>) -> Result<()> {
    if state.text_started {
        println!();
    }

    let Some(writer) = state.wav_writer.as_mut() else {
        return Ok(());
    };
    writer.update_header()?;

    if let Some(path) = audio_out_path {
        println!(
            "[audio] 已保存到 {path} (rate={rate}Hz)",
            rate = writer.sample_rate
        );
    }
    Ok(())
}
/// Mutable output state carried across the messages of one session.
struct SessionState {
// Set once the "assistant: " prefix has been printed by `emit_text`.
text_started: bool,
// Last character of the most recently printed (trimmed) text fragment;
// passed to `needs_space_before` to decide whether to insert a space.
last_char: Option<char>,
// WAV output, created by `handle_inline_data` on the first audio blob when
// an output path is configured.
wav_writer: Option<WavWriter>,
}
impl SessionState {
    /// Creates an empty state: no text printed yet, no previous character
    /// and no WAV writer open.
    const fn new() -> Self {
        Self {
            wav_writer: None,
            last_char: None,
            text_started: false,
        }
    }
}
/// Prints one streamed text fragment on the current assistant line.
///
/// The fragment is trimmed; whitespace-only fragments are dropped. The first
/// fragment is prefixed with `assistant: `. For later fragments a single
/// space is printed first when the raw fragment began with whitespace and
/// `needs_space_before` agrees, based on the previously printed character.
/// Stdout is flushed after each fragment; a flush failure is ignored.
fn emit_text(text: &str, state: &mut SessionState) {
    let body = text.trim();
    // An empty trimmed fragment has no first character: nothing to print.
    let Some(first) = body.chars().next() else {
        return;
    };

    if state.text_started {
        let had_leading_space = text.starts_with(char::is_whitespace);
        if had_leading_space && needs_space_before(state.last_char, first) {
            print!(" ");
        }
    } else {
        print!("assistant: ");
        state.text_started = true;
    }

    print!("{body}");
    std::io::stdout().flush().ok();
    state.last_char = body.chars().next_back();
}
/// Extracts the `rate` parameter from a MIME type such as
/// `audio/pcm;rate=24000`.
///
/// The parameter name is matched case-insensitively, whitespace around the
/// name and value is ignored, and a double-quoted value is accepted. Only
/// the first `rate` parameter is considered.
///
/// Returns `None` when there is no `rate` parameter, when its value is not a
/// valid `u32`, or when it is `0` (a zero sample rate cannot describe
/// playable audio, so callers should fall back to their default).
fn parse_sample_rate(mime_type: &str) -> Option<u32> {
    mime_type
        .split(';')
        .filter_map(|param| param.split_once('='))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("rate"))
        .and_then(|(_, value)| value.trim().trim_matches('"').parse::<u32>().ok())
        .filter(|&rate| rate > 0)
}
/// Incremental writer for a PCM WAV file: the header is written up front and
/// rewritten later once the amount of audio data is known.
struct WavWriter {
// Output file; audio chunks are appended after the header.
file: File,
// Number of audio bytes written so far (saturating); stored in the header's
// data chunk size field.
data_len: u32,
// Sample rate in Hz written into the header.
sample_rate: u32,
// Channel count written into the header; `create` sets this to 1.
channels: u16,
// Bits per sample written into the header; `create` sets this to 16.
bits_per_sample: u16,
}
impl WavWriter {
    /// Creates (or truncates) the file at `path` and writes an initial WAV
    /// header for mono, 16-bit PCM at `sample_rate` with a data length of 0.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created or the header cannot
    /// be written.
    fn create(path: &str, sample_rate: u32) -> rust_genai::Result<Self> {
        let mut writer = Self {
            file: File::create(path)?,
            data_len: 0,
            sample_rate,
            channels: 1,
            bits_per_sample: 16,
        };
        writer.write_header()?;
        Ok(writer)
    }

    /// Appends raw audio bytes and adds their length to the running total.
    ///
    /// The chunk size is validated before anything is written, so a rejected
    /// chunk leaves the file untouched. The running total saturates at
    /// `u32::MAX` instead of overflowing.
    ///
    /// # Errors
    ///
    /// Returns `InvalidConfig` if the chunk length does not fit in a `u32`,
    /// or an I/O error if the write fails.
    fn write_chunk(&mut self, data: &[u8]) -> rust_genai::Result<()> {
        let chunk_len =
            u32::try_from(data.len()).map_err(|_| rust_genai::Error::InvalidConfig {
                message: "audio chunk too large".into(),
            })?;
        self.file.write_all(data)?;
        self.data_len = self.data_len.saturating_add(chunk_len);
        Ok(())
    }

    /// Builds the 44-byte RIFF/WAVE header for the current state.
    ///
    /// All derived fields use saturating arithmetic so that an extreme
    /// sample rate (it originates from a server-supplied MIME type) cannot
    /// cause an integer overflow.
    fn header_bytes(&self) -> Vec<u8> {
        // Bytes per sample frame across all channels.
        let block_align = self.channels.saturating_mul(self.bits_per_sample / 8);
        // Bytes of audio per second = sample rate * bytes per frame.
        let byte_rate = self.sample_rate.saturating_mul(u32::from(block_align));
        // RIFF chunk size covers everything after the first 8 bytes:
        // 36 remaining header bytes plus the audio data.
        let riff_len = 36u32.saturating_add(self.data_len);

        let mut header = Vec::with_capacity(44);
        header.extend_from_slice(b"RIFF");
        header.extend_from_slice(&riff_len.to_le_bytes());
        header.extend_from_slice(b"WAVE");
        header.extend_from_slice(b"fmt ");
        header.extend_from_slice(&16u32.to_le_bytes()); // size of the fmt chunk
        header.extend_from_slice(&1u16.to_le_bytes()); // audio format 1 = PCM
        header.extend_from_slice(&self.channels.to_le_bytes());
        header.extend_from_slice(&self.sample_rate.to_le_bytes());
        header.extend_from_slice(&byte_rate.to_le_bytes());
        header.extend_from_slice(&block_align.to_le_bytes());
        header.extend_from_slice(&self.bits_per_sample.to_le_bytes());
        header.extend_from_slice(b"data");
        header.extend_from_slice(&self.data_len.to_le_bytes());
        header
    }

    /// Writes the header at the start of the file with a single write, then
    /// moves the cursor back to the end so later chunks are appended.
    ///
    /// # Errors
    ///
    /// Returns an error if seeking or writing fails.
    fn write_header(&mut self) -> rust_genai::Result<()> {
        let header = self.header_bytes();
        self.file.seek(SeekFrom::Start(0))?;
        self.file.write_all(&header)?;
        self.file.seek(SeekFrom::End(0))?;
        Ok(())
    }

    /// Rewrites the header so it reflects the amount of audio written so far.
    ///
    /// # Errors
    ///
    /// Returns an error if seeking or writing fails.
    fn update_header(&mut self) -> rust_genai::Result<()> {
        self.write_header()
    }
}
/// Decides whether a space should be printed between the previously emitted
/// character `last` and the first character of the next fragment.
///
/// Rules, in order:
/// 1. No previous character: no space.
/// 2. Never a space before closing/terminal punctuation (ASCII, CJK or
///    fullwidth). This is checked first because the CJK ranges below also
///    contain CJK and fullwidth punctuation.
/// 3. CJK after CJK: no space. CJK after any other alphanumeric: space.
/// 4. An alphanumeric after punctuation or after another alphanumeric: space.
/// 5. Anything else: no space.
fn needs_space_before(last: Option<char>, current_first: char) -> bool {
    /// Han ideographs (incl. extensions), CJK symbols and punctuation,
    /// halfwidth/fullwidth forms, Hiragana and Katakana.
    fn is_cjk(c: char) -> bool {
        matches!(
            c,
            '\u{4E00}'..='\u{9FFF}'
                | '\u{3400}'..='\u{4DBF}'
                | '\u{20000}'..='\u{2A6DF}'
                | '\u{2A700}'..='\u{2B73F}'
                | '\u{2B740}'..='\u{2B81F}'
                | '\u{2B820}'..='\u{2CEAF}'
                | '\u{3000}'..='\u{303F}'
                | '\u{FF00}'..='\u{FFEF}'
                | '\u{3040}'..='\u{309F}'
                | '\u{30A0}'..='\u{30FF}'
        )
    }

    /// Closing or terminal punctuation that attaches to the preceding text.
    fn is_punctuation(c: char) -> bool {
        matches!(
            c,
            '.' | ','
                | '!'
                | '?'
                | ';'
                | ':'
                | ')'
                | ']'
                | '}'
                | '\''
                | '"'
                | '。'
                | '】'
                | '』'
                | '\u{FF0C}' // fullwidth comma
                | '\u{FF01}' // fullwidth exclamation mark
                | '\u{FF1F}' // fullwidth question mark
                | '\u{FF1B}' // fullwidth semicolon
                | '\u{FF1A}' // fullwidth colon
                | '\u{FF09}' // fullwidth right parenthesis
                | '\u{201D}' // right double quotation mark
                | '\u{2019}' // right single quotation mark
        )
    }

    let Some(last_char) = last else {
        return false;
    };

    // Must run before the CJK branch: characters such as '。' are inside the
    // CJK ranges and would otherwise get a space after a Latin letter.
    if is_punctuation(current_first) {
        return false;
    }

    if is_cjk(current_first) {
        if is_cjk(last_char) {
            return false;
        }
        if last_char.is_alphanumeric() {
            return true;
        }
    }

    if is_punctuation(last_char) && current_first.is_alphanumeric() {
        return true;
    }

    current_first.is_alphanumeric() && last_char.is_alphanumeric()
}