use serde::{Deserialize, Serialize};
/// Kind of participant that produced a dialogue turn.
///
/// The `Default` value is [`ParticipantType::Human`]. Note that
/// `DialogueTurn::new` does not use the default; it starts turns as
/// [`ParticipantType::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ParticipantType {
/// A human participant (label `"human"`). This is the `Default` variant.
#[default]
Human,
/// An agent participant (label `"agent"`).
Agent,
/// The kind of participant has not been determined (label `"unknown"`).
Unknown,
}
impl ParticipantType {
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
ParticipantType::Human => "human",
ParticipantType::Agent => "agent",
ParticipantType::Unknown => "unknown",
}
}
}
/// Category of speech act attached to a dialogue turn.
///
/// The first five variants (`Continuer` through `BackChannel`) are the ones
/// `SpeechActType::is_response_token` reports as response tokens; all other
/// variants are not. The enum is `#[non_exhaustive]`, so code outside this
/// crate must include a wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SpeechActType {
/// Response token. `classify_response_token` maps e.g. `"uh huh"`, `"yeah"`
/// (English table) and `"oui"` (French table) to this variant.
Continuer,
/// Response token. `classify_response_token` maps e.g. `"okay"`, `"got it"`
/// (English table) and `"d'accord"` (French table) to this variant.
Acknowledgment,
/// Response token. `classify_response_token` maps e.g. `"wow"`, `"really"`
/// (English table) and `"ah"` (French table) to this variant.
Assessment,
/// Response token. `classify_response_token` maps e.g. `"right"`, `"exactly"`
/// (English table) and `"exactement"` (French table) to this variant.
Alignment,
/// Response token. No lookup table in `classify_response_token` produces
/// this variant; it has to be assigned by the caller.
BackChannel,
/// Not a response token (label `"question"`).
Question,
/// Not a response token (label `"statement"`).
Statement,
/// Not a response token (label `"request"`).
Request,
/// Not a response token. `classify_response_token` maps e.g. `"bye"`
/// (English table) and `"au revoir"` (French table) to this variant.
Farewell,
/// Not a response token. `classify_response_token` maps e.g. `"hello"`
/// (English table) and `"bonjour"` (French table) to this variant.
Greeting,
/// Not a response token (label `"other"`).
Other,
}
impl SpeechActType {
#[must_use]
pub const fn is_response_token(&self) -> bool {
matches!(
self,
SpeechActType::Continuer
| SpeechActType::Acknowledgment
| SpeechActType::Assessment
| SpeechActType::Alignment
| SpeechActType::BackChannel
)
}
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
SpeechActType::Continuer => "continuer",
SpeechActType::Acknowledgment => "acknowledgment",
SpeechActType::Assessment => "assessment",
SpeechActType::Alignment => "alignment",
SpeechActType::BackChannel => "backchannel",
SpeechActType::Question => "question",
SpeechActType::Statement => "statement",
SpeechActType::Request => "request",
SpeechActType::Farewell => "farewell",
SpeechActType::Greeting => "greeting",
SpeechActType::Other => "other",
}
}
}
/// A single turn in a dialogue: what was said, by whom, plus annotations.
///
/// All fields are public. `DialogueTurn::new` fills in the text and speaker
/// and leaves every annotation at a neutral starting value; the `with_*` /
/// `as_aside` builder methods set the annotations individually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DialogueTurn {
/// Text of the turn.
pub text: String,
/// Label of the speaker. `DialogueContext` uses it to build its participant
/// list and to prefix each line of `full_text`.
pub speaker: String,
/// Kind of participant; `DialogueTurn::new` starts it as `Unknown`.
pub participant_type: ParticipantType,
/// Speech act assigned to the turn, or `None` if it has not been annotated.
pub speech_act: Option<SpeechActType>,
/// Whether the turn is flagged as an aside; counted by
/// `DialogueContext::aside_count`.
/// NOTE(review): presumably marks speech not addressed to the main
/// interlocutor — confirm against the annotation scheme.
pub is_aside: bool,
/// Whether the turn is flagged as having triggered a cutoff.
/// `DialogueContext::cutoff_count` counts turns that have this flag AND are
/// response tokens.
/// NOTE(review): what exactly gets "cut off" is not established in this
/// file — confirm.
pub triggered_cutoff: bool,
/// Position of the turn. `DialogueContext::add_turn` overwrites this with
/// the zero-based index at which the turn is stored.
pub turn_number: usize,
/// Who the turn is addressed to, if known. When present,
/// `DialogueContext::add_turn` copies it into `current_addressee`.
pub addressee: Option<String>,
/// Language tag for the turn, if known. Nothing in this file reads it.
/// NOTE(review): presumably the same codes `classify_response_token`
/// accepts (e.g. "fr") — confirm.
pub language: Option<String>,
/// Start of the turn's span, set together with `end` by `with_span`.
/// NOTE(review): the unit (byte vs. character offset) is not established
/// in this file — confirm.
pub start: usize,
/// End of the turn's span, set together with `start` by `with_span`.
/// NOTE(review): whether this bound is exclusive is not established in
/// this file — confirm.
pub end: usize,
}
impl DialogueTurn {
    /// Creates a turn from its text and speaker label.
    ///
    /// Every annotation starts neutral: participant type `Unknown`, no speech
    /// act, no addressee, no language, both flags `false`, and turn number
    /// and span set to `0`.
    #[must_use]
    pub fn new(text: impl Into<String>, speaker: impl Into<String>) -> Self {
        let (text, speaker) = (text.into(), speaker.into());
        Self {
            text,
            speaker,
            participant_type: ParticipantType::Unknown,
            speech_act: None,
            addressee: None,
            language: None,
            is_aside: false,
            triggered_cutoff: false,
            turn_number: 0,
            start: 0,
            end: 0,
        }
    }

    /// Returns the turn with its participant type replaced by `pt`.
    #[must_use]
    pub fn with_participant_type(self, pt: ParticipantType) -> Self {
        Self {
            participant_type: pt,
            ..self
        }
    }

    /// Returns the turn with its speech act set to `Some(act)`.
    #[must_use]
    pub fn with_speech_act(self, act: SpeechActType) -> Self {
        Self {
            speech_act: Some(act),
            ..self
        }
    }

    /// Returns the turn with its aside flag set to `is_aside`.
    #[must_use]
    pub fn as_aside(self, is_aside: bool) -> Self {
        Self { is_aside, ..self }
    }

    /// Returns the turn with its cutoff flag set to `triggered`.
    #[must_use]
    pub fn with_triggered_cutoff(self, triggered: bool) -> Self {
        Self {
            triggered_cutoff: triggered,
            ..self
        }
    }

    /// Returns the turn with its turn number set to `n`.
    #[must_use]
    pub fn with_turn_number(self, n: usize) -> Self {
        Self {
            turn_number: n,
            ..self
        }
    }

    /// Returns the turn with its addressee set to `Some(addr)`.
    #[must_use]
    pub fn with_addressee(self, addr: impl Into<String>) -> Self {
        Self {
            addressee: Some(addr.into()),
            ..self
        }
    }

    /// Returns the turn with its language set to `Some(lang)`.
    #[must_use]
    pub fn with_language(self, lang: impl Into<String>) -> Self {
        Self {
            language: Some(lang.into()),
            ..self
        }
    }

    /// Returns the turn with its span replaced by `start` and `end`.
    /// The two values are stored as given; no ordering check is performed.
    #[must_use]
    pub fn with_span(self, start: usize, end: usize) -> Self {
        Self { start, end, ..self }
    }

    /// Reports whether the turn carries a speech act that is a response
    /// token. A turn without a speech act is never a response token.
    #[must_use]
    pub fn is_response_token(&self) -> bool {
        matches!(self.speech_act, Some(act) if act.is_response_token())
    }

    /// Reports whether the participant type is exactly `Human`.
    #[must_use]
    pub fn is_human(&self) -> bool {
        self.participant_type == ParticipantType::Human
    }

    /// Reports whether the participant type is exactly `Agent`.
    #[must_use]
    pub fn is_agent(&self) -> bool {
        self.participant_type == ParticipantType::Agent
    }
}
/// Accumulated state of a dialogue: its turns plus bookkeeping derived from them.
///
/// The derived `Default` yields an empty context (no turns, no participants,
/// no addressee, no id). Turns should be appended through
/// `DialogueContext::add_turn`, which keeps `participants` and
/// `current_addressee` in step with `turns`; pushing to `turns` directly
/// bypasses that bookkeeping.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DialogueContext {
/// Turns in the order they were added.
pub turns: Vec<DialogueTurn>,
/// Distinct speaker labels, in order of first appearance, as maintained by
/// `add_turn`.
pub participants: Vec<String>,
/// Addressee of the most recently added turn that had one; turns without an
/// addressee leave this value untouched.
pub current_addressee: Option<String>,
/// Optional identifier for the dialogue. No method in this file reads or
/// writes it.
pub dialogue_id: Option<String>,
}
impl DialogueContext {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn add_turn(&mut self, mut turn: DialogueTurn) {
turn.turn_number = self.turns.len();
if !self.participants.contains(&turn.speaker) {
self.participants.push(turn.speaker.clone());
}
if let Some(ref addr) = turn.addressee {
self.current_addressee = Some(addr.clone());
}
self.turns.push(turn);
}
#[must_use]
pub fn last_turns(&self, n: usize) -> &[DialogueTurn] {
let start = self.turns.len().saturating_sub(n);
&self.turns[start..]
}
#[must_use]
pub fn turns_by_speaker(&self, speaker: &str) -> Vec<&DialogueTurn> {
self.turns.iter().filter(|t| t.speaker == speaker).collect()
}
#[must_use]
pub fn cutoff_count(&self) -> usize {
self.turns
.iter()
.filter(|t| t.is_response_token() && t.triggered_cutoff)
.count()
}
#[must_use]
pub fn aside_count(&self) -> usize {
self.turns.iter().filter(|t| t.is_aside).count()
}
#[must_use]
pub fn full_text(&self) -> String {
self.turns
.iter()
.map(|t| format!("{}: {}", t.speaker, t.text))
.collect::<Vec<_>>()
.join("\n")
}
}
/// Classifies a short utterance as a speech act using fixed lookup tables.
///
/// # Parameters
/// * `token` – the utterance to classify. Surrounding whitespace is ignored
///   and the comparison is case-insensitive (`"  OUI\n"` matches `"oui"`).
/// * `lang` – optional language tag. Only its primary subtag is considered,
///   case-insensitively, so `"fr"`, `"FR"`, `"fr-FR"` and `"fr_CA"` all
///   select the French table. `None` means `"en"`. Any other language falls
///   back to the English table.
///
/// # Returns
/// `Some(act)` when the normalized token appears in the selected table,
/// otherwise `None`. The tables cover `Continuer`, `Acknowledgment`,
/// `Assessment`, `Alignment`, `Greeting` and `Farewell`; no entry maps to
/// `BackChannel` or any other variant.
#[must_use]
pub fn classify_response_token(token: &str, lang: Option<&str>) -> Option<SpeechActType> {
    // Table keys never carry surrounding whitespace, so stray spaces or a
    // trailing newline from the caller must not turn a known token into a miss.
    let normalized = token.trim().to_lowercase();

    // Reduce a tag such as "fr-FR" or "fr_CA" to its primary subtag so that
    // regional variants do not silently fall through to the English table.
    let primary_lang = lang
        .unwrap_or("en")
        .trim()
        .split(['-', '_'])
        .next()
        .unwrap_or("");

    if primary_lang.eq_ignore_ascii_case("fr") {
        match normalized.as_str() {
            "oui" | "ouais" | "mm" | "mhm" => Some(SpeechActType::Continuer),
            "d'accord" | "ok" | "okai" | "okay" => Some(SpeechActType::Acknowledgment),
            "ah" | "oh" | "wow" => Some(SpeechActType::Assessment),
            "exactement" | "voilà" | "c'est ça" => Some(SpeechActType::Alignment),
            "salut" | "bonjour" => Some(SpeechActType::Greeting),
            "au revoir" | "à bientôt" => Some(SpeechActType::Farewell),
            _ => None,
        }
    } else {
        // English table; also the fallback for every language without its own table.
        match normalized.as_str() {
            "uh huh" | "mm-hmm" | "mm" | "mhm" | "yeah" => Some(SpeechActType::Continuer),
            "okay" | "ok" | "got it" | "i see" => Some(SpeechActType::Acknowledgment),
            "wow" | "really" | "oh" | "interesting" => Some(SpeechActType::Assessment),
            "right" | "exactly" | "yes" => Some(SpeechActType::Alignment),
            "hello" | "hi" | "hey" => Some(SpeechActType::Greeting),
            "bye" | "goodbye" | "see you" => Some(SpeechActType::Farewell),
            _ => None,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the dialogue model types and the response-token classifier.

    use super::*;

    // --- ParticipantType -------------------------------------------------

    #[test]
    fn test_participant_type() {
        assert_eq!(ParticipantType::Human.as_str(), "human");
        assert_eq!(ParticipantType::Agent.as_str(), "agent");
        assert_eq!(ParticipantType::Unknown.as_str(), "unknown");
        // `Default` and `DialogueTurn::new` intentionally differ; pin both.
        assert_eq!(ParticipantType::default(), ParticipantType::Human);
        assert_eq!(
            DialogueTurn::new("x", "y").participant_type,
            ParticipantType::Unknown
        );
    }

    // --- SpeechActType ---------------------------------------------------

    #[test]
    fn test_speech_act_response_token() {
        assert!(SpeechActType::Continuer.is_response_token());
        assert!(SpeechActType::Acknowledgment.is_response_token());
        assert!(SpeechActType::Assessment.is_response_token());
        assert!(SpeechActType::Alignment.is_response_token());
        assert!(SpeechActType::BackChannel.is_response_token());
        assert!(!SpeechActType::Question.is_response_token());
        assert!(!SpeechActType::Statement.is_response_token());
    }

    #[test]
    fn speech_act_non_response_tokens() {
        let non_response = [
            SpeechActType::Question,
            SpeechActType::Statement,
            SpeechActType::Request,
            SpeechActType::Farewell,
            SpeechActType::Greeting,
            SpeechActType::Other,
        ];
        for act in non_response {
            assert!(
                !act.is_response_token(),
                "{act:?} should not be a response token"
            );
        }
    }

    #[test]
    fn speech_act_as_str_labels_are_lowercase_ascii() {
        let all = [
            SpeechActType::Continuer,
            SpeechActType::Acknowledgment,
            SpeechActType::Assessment,
            SpeechActType::Alignment,
            SpeechActType::BackChannel,
            SpeechActType::Question,
            SpeechActType::Statement,
            SpeechActType::Request,
            SpeechActType::Farewell,
            SpeechActType::Greeting,
            SpeechActType::Other,
        ];
        for act in all {
            let label = act.as_str();
            assert!(!label.is_empty(), "{act:?} has empty label");
            assert!(
                label.chars().all(|c| c.is_ascii_lowercase()),
                "{act:?} label {label:?} contains non-lowercase-ascii"
            );
        }
        // The only label that is not simply the lowercased variant name split
        // on case boundaries; pin it explicitly.
        assert_eq!(SpeechActType::BackChannel.as_str(), "backchannel");
    }

    // --- DialogueTurn ----------------------------------------------------

    #[test]
    fn test_dialogue_turn() {
        let turn = DialogueTurn::new("oui", "EMM")
            .with_participant_type(ParticipantType::Human)
            .with_speech_act(SpeechActType::Continuer)
            .with_triggered_cutoff(true);
        assert!(turn.is_response_token());
        assert!(turn.is_human());
        assert!(!turn.is_agent());
        assert!(turn.triggered_cutoff);
    }

    #[test]
    fn test_aside() {
        let turn = DialogueTurn::new("°tu n'es pas très intelligent°", "ANA")
            .with_participant_type(ParticipantType::Human)
            .as_aside(true);
        assert!(turn.is_aside);
        assert!(turn.is_human());
    }

    #[test]
    fn dialogue_turn_no_speech_act_is_not_response_token() {
        let turn = DialogueTurn::new("hello", "EMM");
        assert!(!turn.is_response_token());
    }

    #[test]
    fn dialogue_turn_optional_builders() {
        let turn = DialogueTurn::new("salut", "EMM")
            .with_turn_number(4)
            .with_addressee("GPT")
            .with_language("fr")
            .with_span(10, 15);
        assert_eq!(turn.turn_number, 4);
        assert_eq!(turn.addressee.as_deref(), Some("GPT"));
        assert_eq!(turn.language.as_deref(), Some("fr"));
        assert_eq!((turn.start, turn.end), (10, 15));
    }

    // --- DialogueContext -------------------------------------------------

    #[test]
    fn test_dialogue_context() {
        let mut ctx = DialogueContext::new();
        ctx.add_turn(
            DialogueTurn::new("Bonjour", "GPT")
                .with_participant_type(ParticipantType::Agent)
                .with_speech_act(SpeechActType::Greeting),
        );
        ctx.add_turn(
            DialogueTurn::new("oui", "EMM")
                .with_participant_type(ParticipantType::Human)
                .with_speech_act(SpeechActType::Continuer)
                .with_triggered_cutoff(true),
        );
        assert_eq!(ctx.turns.len(), 2);
        assert_eq!(ctx.participants.len(), 2);
        assert_eq!(ctx.participants, ["GPT", "EMM"]);
        assert_eq!(ctx.turns[0].turn_number, 0);
        assert_eq!(ctx.turns[1].turn_number, 1);
        assert_eq!(ctx.cutoff_count(), 1);
    }

    #[test]
    fn dialogue_context_cutoff_count_requires_response_token() {
        let mut ctx = DialogueContext::new();
        // Cutoff flag without a response-token speech act: not counted.
        ctx.add_turn(
            DialogueTurn::new("stop", "A")
                .with_speech_act(SpeechActType::Statement)
                .with_triggered_cutoff(true),
        );
        ctx.add_turn(DialogueTurn::new("wait", "A").with_triggered_cutoff(true));
        // Response token without the cutoff flag: not counted.
        ctx.add_turn(DialogueTurn::new("mhm", "B").with_speech_act(SpeechActType::Continuer));
        assert_eq!(ctx.cutoff_count(), 0);

        ctx.add_turn(
            DialogueTurn::new("ok", "B")
                .with_speech_act(SpeechActType::Acknowledgment)
                .with_triggered_cutoff(true),
        );
        assert_eq!(ctx.cutoff_count(), 1);
    }

    #[test]
    fn dialogue_context_last_turns() {
        let mut ctx = DialogueContext::new();
        for i in 0..5 {
            ctx.add_turn(DialogueTurn::new(format!("turn {i}"), "A"));
        }
        assert_eq!(ctx.last_turns(3).len(), 3);
        assert_eq!(ctx.last_turns(3)[0].text, "turn 2");
        // Asking for more than exist returns everything.
        assert_eq!(ctx.last_turns(10).len(), 5);
        assert_eq!(ctx.last_turns(0).len(), 0);
    }

    #[test]
    fn dialogue_context_turns_by_speaker() {
        let mut ctx = DialogueContext::new();
        ctx.add_turn(DialogueTurn::new("hi", "A"));
        ctx.add_turn(DialogueTurn::new("hey", "B"));
        ctx.add_turn(DialogueTurn::new("sup", "A"));
        assert_eq!(ctx.turns_by_speaker("A").len(), 2);
        assert_eq!(ctx.turns_by_speaker("B").len(), 1);
        assert_eq!(ctx.turns_by_speaker("C").len(), 0);
        // A repeated speaker is listed only once.
        assert_eq!(ctx.participants, ["A", "B"]);
    }

    #[test]
    fn dialogue_context_aside_count_and_full_text() {
        let mut ctx = DialogueContext::new();
        assert_eq!(ctx.full_text(), "");
        ctx.add_turn(DialogueTurn::new("hello", "A"));
        ctx.add_turn(DialogueTurn::new("psst", "B").as_aside(true));
        ctx.add_turn(DialogueTurn::new("what?", "A").as_aside(true));
        assert_eq!(ctx.aside_count(), 2);
        let text = ctx.full_text();
        assert!(text.contains("A: hello"));
        assert!(text.contains("B: psst"));
        assert!(text.contains("A: what?"));
        // Exact layout: one line per turn, newline-joined, no trailing newline.
        assert_eq!(text, "A: hello\nB: psst\nA: what?");
    }

    #[test]
    fn dialogue_context_addressee_tracking() {
        let mut ctx = DialogueContext::new();
        ctx.add_turn(DialogueTurn::new("hi", "A").with_addressee("B"));
        assert_eq!(ctx.current_addressee.as_deref(), Some("B"));
        // A turn without an addressee keeps the previous one.
        ctx.add_turn(DialogueTurn::new("ok", "B"));
        assert_eq!(ctx.current_addressee.as_deref(), Some("B"));
        ctx.add_turn(DialogueTurn::new("hey C", "A").with_addressee("C"));
        assert_eq!(ctx.current_addressee.as_deref(), Some("C"));
    }

    // --- classify_response_token -----------------------------------------

    #[test]
    fn test_french_response_tokens() {
        assert_eq!(
            classify_response_token("oui", Some("fr")),
            Some(SpeechActType::Continuer)
        );
        assert_eq!(
            classify_response_token("d'accord", Some("fr")),
            Some(SpeechActType::Acknowledgment)
        );
        assert_eq!(
            classify_response_token("ok", Some("fr")),
            Some(SpeechActType::Acknowledgment)
        );
        assert_eq!(
            classify_response_token("exactement", Some("fr")),
            Some(SpeechActType::Alignment)
        );
    }

    #[test]
    fn test_english_response_tokens() {
        assert_eq!(
            classify_response_token("uh huh", None),
            Some(SpeechActType::Continuer)
        );
        assert_eq!(
            classify_response_token("okay", None),
            Some(SpeechActType::Acknowledgment)
        );
        assert_eq!(
            classify_response_token("exactly", None),
            Some(SpeechActType::Alignment)
        );
    }

    #[test]
    fn classify_response_token_case_insensitive() {
        assert_eq!(
            classify_response_token("OUI", Some("fr")),
            Some(SpeechActType::Continuer)
        );
        assert_eq!(
            classify_response_token("Okay", None),
            Some(SpeechActType::Acknowledgment)
        );
        assert_eq!(
            classify_response_token("WOW", Some("fr")),
            Some(SpeechActType::Assessment)
        );
    }

    #[test]
    fn classify_response_token_unknown_returns_none() {
        assert_eq!(classify_response_token("blargfizzle", None), None);
        assert_eq!(classify_response_token("xyzzy", Some("fr")), None);
        assert_eq!(classify_response_token("", None), None);
        // French-only tokens are not in the English table.
        assert_eq!(classify_response_token("oui", None), None);
    }

    #[test]
    fn classify_greetings_and_farewells() {
        assert_eq!(
            classify_response_token("hello", None),
            Some(SpeechActType::Greeting)
        );
        assert_eq!(
            classify_response_token("bye", None),
            Some(SpeechActType::Farewell)
        );
        assert_eq!(
            classify_response_token("bonjour", Some("fr")),
            Some(SpeechActType::Greeting)
        );
        assert_eq!(
            classify_response_token("au revoir", Some("fr")),
            Some(SpeechActType::Farewell)
        );
    }
}