use std::collections::VecDeque;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use llm::chat::{ChatMessage, ChatRole, Usage};
use llm::LLMProvider;
use super::events::{DialogueEvent, StopReason};
use super::participant::{ActiveParticipant, DialogueParticipant, ParticipantId};
use super::DialogueConfig;
/// Strategy for choosing which participant speaks next.
///
/// The default is [`TurnMode::RoundRobin`].
// NOTE(review): nothing in this file reads `TurnMode`; the controller here
// always rotates in registration order. Confirm where `Directed`/`Free` are
// honoured.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
pub enum TurnMode {
    /// Participants take turns in a fixed rotation.
    #[default]
    RoundRobin,
    /// Presumably the next speaker is chosen explicitly. TODO confirm.
    Directed,
    /// Presumably there is no fixed speaking order. TODO confirm.
    Free,
}
/// Lifecycle state of a dialogue controller.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DialogueState {
    /// Not yet started. This is the default.
    #[default]
    Idle,
    /// A response is in progress. Presumably set by the driving loop via
    /// `set_state`; confirm.
    Streaming,
    /// Ready for the next turn.
    WaitingNext,
    /// Temporarily halted.
    Paused,
    /// Terminated.
    Stopped,
}
/// One completed contribution to the dialogue, as stored in the history.
#[derive(Clone, Debug)]
pub struct DialogueTurn {
    /// Who produced this turn.
    pub participant_id: ParticipantId,
    /// The text of the turn.
    pub content: String,
    /// When the turn was recorded (UTC).
    pub timestamp: DateTime<Utc>,
    /// Token usage reported for the turn, if any.
    pub usage: Option<Usage>,
}
/// Coordinates a multi-participant dialogue: the participant roster, the
/// turn pointer, queued user messages, the turn history and event emission.
// Field declaration order is also the drop order; keep it stable.
pub struct DialogueController {
// Every registered participant in turn order; kicked ones stay with `active == false`.
participants: Vec<ActiveParticipant>,
// Configuration supplied at construction (e.g. the initial prompt).
config: DialogueConfig,
// Index into `participants` where the search for the next speaker begins.
current_index: usize,
// Current lifecycle state.
state: DialogueState,
// User messages that have not been drained yet (FIFO).
user_message_queue: VecDeque<String>,
// Recorded turns, oldest first.
history: Vec<DialogueTurn>,
// Optional sink for lifecycle events; `None` means events are discarded.
event_sender: Option<mpsc::UnboundedSender<DialogueEvent>>,
}
impl DialogueController {
#[must_use]
pub fn new(config: DialogueConfig) -> Self {
Self {
participants: Vec::new(),
config,
current_index: 0,
state: DialogueState::Idle,
user_message_queue: VecDeque::new(),
history: Vec::new(),
event_sender: None,
}
}
pub fn set_event_sender(&mut self, sender: mpsc::UnboundedSender<DialogueEvent>) {
self.event_sender = Some(sender);
}
#[must_use]
pub fn create_event_channel(&mut self) -> mpsc::UnboundedReceiver<DialogueEvent> {
let (tx, rx) = mpsc::unbounded_channel();
self.event_sender = Some(tx);
rx
}
#[must_use]
pub const fn state(&self) -> DialogueState {
self.state
}
#[must_use]
pub const fn config(&self) -> &DialogueConfig {
&self.config
}
#[must_use]
pub fn participants(&self) -> &[ActiveParticipant] {
&self.participants
}
#[must_use]
pub fn history(&self) -> &[DialogueTurn] {
&self.history
}
#[must_use]
pub fn active_participant_count(&self) -> usize {
self.participants.iter().filter(|p| p.config.active).count()
}
pub fn add_participant(
&mut self,
config: DialogueParticipant,
provider: Arc<dyn LLMProvider>,
) -> ParticipantId {
let id = config.id;
let name = config.display_name.clone();
self.participants
.push(ActiveParticipant { config, provider });
self.emit_event(DialogueEvent::ParticipantJoined {
participant_id: id,
participant_name: name,
});
id
}
pub fn remove_participant(&mut self, id: ParticipantId) -> bool {
if let Some(pos) = self.participants.iter().position(|p| p.config.id == id) {
self.participants.remove(pos);
self.emit_event(DialogueEvent::ParticipantLeft { participant_id: id });
if self.current_index >= self.participants.len() && !self.participants.is_empty() {
self.current_index = 0;
}
if self.active_participant_count() == 0 {
self.stop(StopReason::AllParticipantsLeft);
}
true
} else {
false
}
}
pub fn kick_participant(&mut self, id: ParticipantId) -> bool {
if let Some(participant) = self.participants.iter_mut().find(|p| p.config.id == id) {
participant.config.active = false;
self.emit_event(DialogueEvent::ParticipantLeft { participant_id: id });
if self.active_participant_count() == 0 {
self.stop(StopReason::AllParticipantsLeft);
}
true
} else {
false
}
}
pub fn invite_participant(&mut self, id: ParticipantId) -> bool {
if let Some(participant) = self.participants.iter_mut().find(|p| p.config.id == id) {
participant.config.active = true;
let name = participant.config.display_name.clone();
self.emit_event(DialogueEvent::ParticipantJoined {
participant_id: id,
participant_name: name,
});
true
} else {
false
}
}
pub fn inject_user_message(&mut self, content: String) {
self.user_message_queue.push_back(content.clone());
self.emit_event(DialogueEvent::UserMessage { content });
}
pub fn start(&mut self) {
if self.participants.is_empty() {
return;
}
self.state = DialogueState::WaitingNext;
self.emit_event(DialogueEvent::Started);
if !self.config.initial_prompt.is_empty() {
self.user_message_queue
.push_back(self.config.initial_prompt.clone());
}
}
pub fn stop(&mut self, reason: StopReason) {
self.state = DialogueState::Stopped;
self.emit_event(DialogueEvent::Stopped { reason });
}
pub fn pause(&mut self) {
if self.state == DialogueState::WaitingNext {
self.state = DialogueState::Paused;
}
}
pub fn resume(&mut self) {
if self.state == DialogueState::Paused {
self.state = DialogueState::WaitingNext;
}
}
#[must_use]
pub fn next_participant(&self) -> Option<&ActiveParticipant> {
if self.participants.is_empty() {
return None;
}
let start = self.current_index;
let mut idx = start;
loop {
if self.participants[idx].config.active {
return Some(&self.participants[idx]);
}
idx = (idx + 1) % self.participants.len();
if idx == start {
return None; }
}
}
pub fn advance_turn(&mut self) {
if self.participants.is_empty() {
return;
}
self.current_index = (self.current_index + 1) % self.participants.len();
let start = self.current_index;
while !self.participants[self.current_index].config.active {
self.current_index = (self.current_index + 1) % self.participants.len();
if self.current_index == start {
break;
}
}
}
pub fn record_turn(
&mut self,
participant_id: ParticipantId,
content: String,
usage: Option<Usage>,
) {
let turn = DialogueTurn {
participant_id,
content: content.clone(),
timestamp: Utc::now(),
usage: usage.clone(),
};
self.history.push(turn);
self.emit_event(DialogueEvent::TurnCompleted {
participant_id,
content,
usage,
});
}
pub fn drain_user_messages(&mut self) -> Vec<String> {
self.user_message_queue.drain(..).collect()
}
#[must_use]
pub fn build_context_messages(&self, participant: &ActiveParticipant) -> Vec<ChatMessage> {
let mut messages = Vec::new();
for turn in &self.history {
let role = if turn.participant_id == participant.config.id {
ChatRole::Assistant
} else {
ChatRole::User
};
let speaker_name = self
.participants
.iter()
.find(|p| p.config.id == turn.participant_id)
.map(|p| p.config.display_name.as_str())
.unwrap_or("Unknown");
let content = format!("[{}] {}", speaker_name, turn.content);
messages.push(ChatMessage {
role,
message_type: llm::chat::MessageType::Text,
content,
});
}
for msg in &self.user_message_queue {
messages.push(ChatMessage::user().content(msg.clone()).build());
}
messages
}
fn emit_event(&self, event: DialogueEvent) {
if let Some(sender) = &self.event_sender {
let _ = sender.send(event);
}
}
pub fn set_state(&mut self, state: DialogueState) {
self.state = state;
}
#[must_use]
pub fn get_participant(&self, id: ParticipantId) -> Option<&ActiveParticipant> {
self.participants.iter().find(|p| p.config.id == id)
}
#[must_use]
pub fn get_participant_by_index(&self, index: usize) -> Option<&ActiveParticipant> {
self.participants.get(index)
}
#[must_use]
pub fn find_participant_by_name(&self, name: &str) -> Option<&ActiveParticipant> {
self.participants
.iter()
.find(|p| p.config.display_name.eq_ignore_ascii_case(name))
}
}
impl std::fmt::Debug for DialogueController {
    // Prints a compact summary (counts, state and turn pointer) rather than
    // the participants, config or queued messages themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let participant_count = self.participants.len();
        let history_len = self.history.len();
        let mut summary = f.debug_struct("DialogueController");
        summary.field("participants", &participant_count);
        summary.field("state", &self.state);
        summary.field("current_index", &self.current_index);
        summary.field("history_len", &history_len);
        summary.finish()
    }
}