use serde::{Deserialize, Serialize};
pub const PROTOCOL_VERSION: &str = "1.0";
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum ServerMessage {
Ready {
model: String,
sample_rate: u32,
version: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
supported_rates: Vec<u32>,
#[serde(skip_serializing_if = "std::ops::Not::not")]
diarization: bool,
},
Partial {
text: String,
timestamp: f64,
#[serde(skip_serializing_if = "Vec::is_empty")]
words: Vec<crate::inference::WordInfo>,
},
Final {
text: String,
timestamp: f64,
#[serde(skip_serializing_if = "Vec::is_empty")]
words: Vec<crate::inference::WordInfo>,
},
Error {
message: String,
code: String,
#[serde(skip_serializing_if = "Option::is_none")]
retry_after_ms: Option<u32>,
},
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum ClientMessage {
Stop,
Configure {
#[serde(default)]
sample_rate: Option<u32>,
#[serde(default)]
diarization: Option<bool>,
},
}
#[cfg(test)]
mod tests {
use super::*;
use crate::inference::TARGET_SAMPLE_RATE;
#[test]
fn test_protocol_version_constant() {
assert_eq!(PROTOCOL_VERSION, "1.0");
}
#[test]
fn test_ready_serialization_includes_version() {
let msg = ServerMessage::Ready {
model: "test-model".into(),
sample_rate: 48000,
version: PROTOCOL_VERSION.into(),
supported_rates: vec![],
diarization: false,
};
let json = serde_json::to_string(&msg).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(v["type"], "ready");
assert_eq!(v["version"], "1.0");
assert_eq!(v["model"], "test-model");
assert_eq!(v["sample_rate"], 48000);
}
#[test]
fn test_partial_serialization_no_version() {
let msg = ServerMessage::Partial {
text: "hello".into(),
timestamp: 1.0,
words: vec![],
};
let json = serde_json::to_string(&msg).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(v["type"], "partial");
assert!(v.get("version").is_none());
}
#[test]
fn test_final_serialization_no_version() {
let msg = ServerMessage::Final {
text: "hello".into(),
timestamp: 1.0,
words: vec![],
};
let json = serde_json::to_string(&msg).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(v["type"], "final");
assert!(v.get("version").is_none());
}
#[test]
fn test_error_serialization_no_version() {
let msg = ServerMessage::Error {
message: "fail".into(),
code: "err".into(),
retry_after_ms: None,
};
let json = serde_json::to_string(&msg).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(v["type"], "error");
assert!(v.get("version").is_none());
assert!(
v.get("retry_after_ms").is_none(),
"retry_after_ms must be omitted when None"
);
}
#[test]
fn test_error_serialization_with_retry_after() {
let msg = ServerMessage::Error {
message: "Server busy, try again later".into(),
code: "timeout".into(),
retry_after_ms: Some(30_000),
};
let json = serde_json::to_string(&msg).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(v["type"], "error");
assert_eq!(v["code"], "timeout");
assert_eq!(v["retry_after_ms"], 30_000);
}
#[test]
fn test_client_message_stop_deserialize() {
let json = r#"{"type":"stop"}"#;
let msg: ClientMessage = serde_json::from_str(json).unwrap();
assert!(matches!(msg, ClientMessage::Stop));
}
#[test]
fn test_client_message_configure_deserialize() {
let json = r#"{"type":"configure","sample_rate":8000}"#;
let msg: ClientMessage = serde_json::from_str(json).unwrap();
match msg {
ClientMessage::Configure { sample_rate, .. } => assert_eq!(sample_rate, Some(8000)),
_ => panic!("Expected Configure"),
}
}
#[test]
fn test_ready_supported_rates_serialization() {
let msg = ServerMessage::Ready {
model: "test".into(),
sample_rate: 48000,
version: "1.0".into(),
supported_rates: vec![8000, TARGET_SAMPLE_RATE, 24000, 44100, 48000],
diarization: false,
};
let json = serde_json::to_string(&msg).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(v["supported_rates"].as_array().unwrap().len(), 5);
}
#[test]
fn test_ready_empty_supported_rates_omitted() {
let msg = ServerMessage::Ready {
model: "test".into(),
sample_rate: 48000,
version: "1.0".into(),
supported_rates: vec![],
diarization: false,
};
let json = serde_json::to_string(&msg).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert!(v.get("supported_rates").is_none());
}
#[test]
fn test_word_info_speaker_none_omitted() {
let word = crate::inference::WordInfo {
word: "hello".into(),
start: 0.0,
end: 1.0,
confidence: 0.9,
speaker: None,
};
let json = serde_json::to_string(&word).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert!(v.get("speaker").is_none());
}
#[test]
fn test_word_info_speaker_present() {
let word = crate::inference::WordInfo {
word: "hello".into(),
start: 0.0,
end: 1.0,
confidence: 0.9,
speaker: Some(2),
};
let json = serde_json::to_string(&word).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(v["speaker"], 2);
}
#[test]
fn test_configure_diarization_deserialize() {
let json = r#"{"type":"configure","diarization":true}"#;
let msg: ClientMessage = serde_json::from_str(json).unwrap();
match msg {
ClientMessage::Configure { diarization, .. } => assert_eq!(diarization, Some(true)),
_ => panic!("Expected Configure"),
}
}
#[test]
fn test_configure_sample_rate_only() {
let json = r#"{"type":"configure","sample_rate":8000}"#;
let msg: ClientMessage = serde_json::from_str(json).unwrap();
match msg {
ClientMessage::Configure {
sample_rate,
diarization,
} => {
assert_eq!(sample_rate, Some(8000));
assert_eq!(diarization, None);
}
_ => panic!("Expected Configure"),
}
}
#[test]
fn test_ready_diarization_false_omitted() {
let msg = ServerMessage::Ready {
model: "test".into(),
sample_rate: 48000,
version: "1.0".into(),
supported_rates: vec![],
diarization: false,
};
let json = serde_json::to_string(&msg).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert!(v.get("diarization").is_none());
}
}