use crate::persistence::{start_disk_cleanup_task, start_persistence_processor, try_restore_session};
use crate::sequence_manager::SequenceManager;
use crate::session::spawn_session_actor;
use crate::types::{
PersistenceCommand, SessionCommand, SessionHandle, SessionStateSnapshot, ThoughtData,
};
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use kodegen_mcp_schema::sequential_thinking::SequentialThinkingArgs;
use kodegen_mcp_schema::McpError;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Capacity of the bounded `tokio` mpsc channel that carries
/// [`PersistenceCommand`]s from this tool to the persistence processor.
const PERSISTENCE_CHANNEL_CAPACITY: usize = 100;
/// Number of sessions accumulated during `shutdown` before they are flushed
/// to the persistence channel as a single `PersistBatch` command.
const PERSISTENCE_BATCH_SIZE: usize = 50;
/// Front-end for sequential-thinking sessions.
///
/// Owns the in-memory table of live session handles and the sending side of
/// the persistence channel. Cloning is cheap: the session table and sequence
/// manager are shared through `Arc`, and the sender is itself a clonable handle.
#[derive(Clone)]
pub struct SequentialThinkingTool {
/// Live sessions keyed by `(connection_id, sequence_id)`.
sessions: Arc<DashMap<(String, u64), SessionHandle>>,
/// Maps a connection and thought number to the sequence id used in session keys.
sequence_manager: Arc<SequenceManager>,
/// Sending side of the channel consumed by the persistence processor.
persistence_sender: tokio::sync::mpsc::Sender<PersistenceCommand>,
}
impl Default for SequentialThinkingTool {
/// Builds the default tool by delegating to [`SequentialThinkingTool::new`].
fn default() -> Self {
Self::new()
}
}
impl SequentialThinkingTool {
/// Creates a tool with an empty session table and a fresh sequence manager.
///
/// A bounded persistence channel is created here: the receiving half is
/// handed to `start_persistence_processor`, one sender is kept on the tool,
/// and the original sender is handed to `start_disk_cleanup_task`.
///
/// NOTE(review): both helpers presumably spawn background tasks and therefore
/// need a running Tokio runtime — confirm against their definitions.
#[must_use]
pub fn new() -> Self {
    let (persist_tx, persist_rx) = tokio::sync::mpsc::channel(PERSISTENCE_CHANNEL_CAPACITY);

    let instance = Self {
        sessions: Arc::new(DashMap::new()),
        sequence_manager: Arc::new(SequenceManager::new()),
        persistence_sender: persist_tx.clone(),
    };

    start_persistence_processor(persist_rx);
    start_disk_cleanup_task(persist_tx);

    instance
}
pub async fn get_or_create_session(
&self,
connection_id: &str,
thought_number: u32,
) -> Result<(String, u32, tokio::sync::mpsc::Sender<SessionCommand>), McpError> {
let (sequence_id, is_new_sequence, previous_sequence) =
self.sequence_manager.resolve(connection_id, thought_number);
if is_new_sequence
&& let Some(prev_seq) = previous_sequence
{
let old_key = (connection_id.to_string(), prev_seq);
if let Some((_, old_handle)) = self.sessions.remove(&old_key) {
log::debug!(
"Cleaned up previous session for connection {} (sequence {})",
connection_id, prev_seq
);
drop(old_handle);
}
}
let key = (connection_id.to_string(), sequence_id);
let composite_id = format!("{}_{}", connection_id, sequence_id);
if let Some(entry) = self.sessions.get(&key) {
if !entry.value().tx.is_closed() {
entry.value().touch();
return Ok((composite_id, sequence_id as u32, entry.value().tx.clone()));
}
log::debug!("Session {} actor died, removing stale handle", composite_id);
drop(entry); self.sessions.remove(&key);
}
let maybe_restored = try_restore_session(&composite_id, &self.persistence_sender).await;
let handle = match self.sessions.entry(key.clone()) {
Entry::Occupied(entry) => {
log::debug!(
"Session {} already created by another thread, using existing",
composite_id
);
entry.get().touch();
entry.get().clone()
}
Entry::Vacant(entry) => {
let handle = if let Some(restored_handle) = maybe_restored {
log::info!("Restored session {} from disk", composite_id);
restored_handle
} else {
log::debug!("Creating new session {}", composite_id);
let (tx, rx) = tokio::sync::mpsc::channel::<SessionCommand>(100);
spawn_session_actor(rx);
SessionHandle::new(tx)
};
entry.insert(handle.clone());
handle
}
};
Ok((composite_id, sequence_id as u32, handle.tx.clone()))
}
/// Drops the session of the connection's current sequence (if one is
/// registered) and then clears the connection's state in the sequence manager.
///
/// The sequence manager cleanup runs unconditionally, whether or not a
/// session handle was found.
pub fn cleanup_sequence(&self, connection_id: &str) {
    let dropped_sequence = self
        .sequence_manager
        .current(connection_id)
        .filter(|seq_id| {
            let lookup = (connection_id.to_string(), *seq_id);
            self.sessions.remove(&lookup).is_some()
        });

    if let Some(seq_id) = dropped_sequence {
        log::debug!(
            "Removed completed session for connection {} (sequence {})",
            connection_id, seq_id
        );
    }

    self.sequence_manager.cleanup(connection_id);
}
/// Asks the session actor registered under `key` for a snapshot of its state.
///
/// The command sender is cloned out of the session table and the table guard
/// is released *before* any `.await`. Keeping a DashMap guard alive across a
/// suspension point can deadlock other tasks that need the same shard.
///
/// # Errors
///
/// Returns [`McpError::Other`] when
/// - no session is registered under `key`,
/// - the session actor's command channel is closed, or
/// - the actor drops the reply channel without answering.
pub async fn get_session_state(
    &self,
    key: &(String, u64),
) -> Result<SessionStateSnapshot, McpError> {
    // `map` consumes the guard, so it is gone once the sender has been cloned.
    let tx = self
        .sessions
        .get(key)
        .map(|handle| handle.value().tx.clone())
        .ok_or_else(|| McpError::Other(anyhow::anyhow!("Session not found: {:?}", key)))?;

    let (respond_to, rx) = tokio::sync::oneshot::channel();
    tx.send(SessionCommand::GetState { respond_to })
        .await
        .map_err(|_| McpError::Other(anyhow::anyhow!("Session actor terminated")))?;

    rx.await
        .map_err(|_| McpError::Other(anyhow::anyhow!("Failed to receive state")))
}
/// Sends a `Clear` command to the session actor registered under `key` and
/// waits for its acknowledgement.
///
/// The command sender is cloned out of the session table and the table guard
/// is released *before* any `.await`. Keeping a DashMap guard alive across a
/// suspension point can deadlock other tasks that need the same shard.
///
/// # Errors
///
/// Returns [`McpError::Other`] when
/// - no session is registered under `key`,
/// - the session actor's command channel is closed, or
/// - the actor drops the reply channel without answering.
pub async fn clear_session(&self, key: &(String, u64)) -> Result<(), McpError> {
    // `map` consumes the guard, so it is gone once the sender has been cloned.
    let tx = self
        .sessions
        .get(key)
        .map(|handle| handle.value().tx.clone())
        .ok_or_else(|| McpError::Other(anyhow::anyhow!("Session not found: {:?}", key)))?;

    let (respond_to, rx) = tokio::sync::oneshot::channel();
    tx.send(SessionCommand::Clear { respond_to })
        .await
        .map_err(|_| McpError::Other(anyhow::anyhow!("Session actor terminated")))?;

    rx.await
        .map_err(|_| McpError::Other(anyhow::anyhow!("Failed to clear session")))
}
/// Looks up the session registered under `key` and returns its
/// `(created_at, last_activity)` instants.
///
/// # Errors
///
/// Returns [`McpError::Other`] when no session is registered under `key`.
pub fn get_session_info(&self, key: &(String, u64)) -> Result<(Instant, Instant), McpError> {
    match self.sessions.get(key) {
        Some(entry) => {
            let session = entry.value();
            let timestamps = (session.created_at, session.last_activity());
            Ok(timestamps)
        }
        None => Err(McpError::Other(anyhow::anyhow!(
            "Session not found: {:?}",
            key
        ))),
    }
}
async fn cleanup_sessions(&self, max_age: Duration) {
let purge_cutoff = Instant::now()
.checked_sub(max_age)
.unwrap_or_else(Instant::now);
let expired_entries: Vec<((String, u64), bool)> = self.sessions.iter()
.filter_map(|entry| {
let handle = entry.value();
if handle.tx.is_closed() {
log::debug!("Session {:?} has closed channel, marking for removal", entry.key());
return Some((entry.key().clone(), true)); }
let last_activity = handle.last_activity();
if last_activity < purge_cutoff {
log::debug!("Session {:?} expired, marking for persistence", entry.key());
return Some((entry.key().clone(), false)); }
None })
.collect();
if expired_entries.is_empty() {
return;
}
log::debug!("Found {} expired sessions to process", expired_entries.len());
let mut removed_count = 0usize;
let mut retry_count = 0usize;
for (key, is_dead) in expired_entries {
if is_dead {
if self.sessions.remove(&key).is_some() {
log::debug!("Removed dead session: {:?}", key);
removed_count += 1;
}
continue;
}
let persisted = self.try_persist_session(&key).await;
if persisted {
if self.sessions.remove(&key).is_some() {
log::debug!("Removed persisted session: {:?}", key);
removed_count += 1;
}
} else {
log::warn!(
"Failed to persist session {:?}, will retry next cleanup cycle",
key
);
retry_count += 1;
}
}
if removed_count > 0 || retry_count > 0 {
log::info!(
"Session cleanup: {} removed, {} deferred for retry",
removed_count, retry_count
);
}
}
async fn try_persist_session(&self, key: &(String, u64)) -> bool {
let Some(handle_ref) = self.sessions.get(key) else {
return true; };
if handle_ref.tx.is_closed() {
return true;
}
let tx = handle_ref.tx.clone();
let created_at_instant = handle_ref.created_at;
let last_activity_instant = handle_ref.last_activity();
drop(handle_ref);
let (respond_to, rx) = tokio::sync::oneshot::channel();
if tx.send(SessionCommand::GetState { respond_to }).await.is_err() {
return true;
}
let Ok(snapshot) = rx.await else {
return true;
};
let created_at_elapsed = created_at_instant.elapsed();
let created_at = std::time::SystemTime::now()
.checked_sub(created_at_elapsed)
.unwrap_or_else(std::time::SystemTime::now);
let last_activity_elapsed = last_activity_instant.elapsed();
let last_activity = std::time::SystemTime::now()
.checked_sub(last_activity_elapsed)
.unwrap_or_else(std::time::SystemTime::now);
let composite_id = format!("{}_{}", key.0, key.1);
match self.persistence_sender.try_send(PersistenceCommand::Persist {
session_id: composite_id.clone(),
snapshot,
created_at,
last_activity,
}) {
Ok(_) => {
log::debug!("Queued session {} for persistence", composite_id);
true }
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
log::warn!(
"Persistence channel full, deferring session {} to next cycle",
composite_id
);
false }
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
log::error!(
"Persistence channel closed, cannot persist session {}",
composite_id
);
false }
}
}
/// Spawns a detached Tokio task that runs `cleanup_sessions` on a fixed
/// five-minute interval, evicting sessions idle for more than thirty minutes.
///
/// The task owns the `Arc<Self>` it was given and loops forever; its join
/// handle is discarded.
pub fn start_cleanup_task(self: Arc<Self>) {
    let sweep_period = Duration::from_secs(5 * 60);
    let idle_limit = Duration::from_secs(30 * 60);

    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(sweep_period);
        loop {
            ticker.tick().await;
            self.cleanup_sessions(idle_limit).await;
        }
    });
}
pub async fn shutdown(&self) -> Result<(), McpError> {
log::info!("Shutting down sequential thinking tool, persisting active sessions");
let session_keys: Vec<(String, u64)> = self.sessions.iter()
.map(|entry| entry.key().clone())
.collect();
let total_sessions = session_keys.len();
log::debug!("Found {} active sessions to persist", total_sessions);
if total_sessions == 0 {
log::info!("No sessions to persist, shutdown complete");
return Ok(());
}
let mut batch_receivers: Vec<tokio::sync::oneshot::Receiver<Result<usize, String>>> = Vec::new();
let mut current_batch: Vec<(String, SessionStateSnapshot, std::time::SystemTime, std::time::SystemTime)> = Vec::new();
let mut sessions_queued = 0usize;
for key in session_keys {
let composite_id = format!("{}_{}", key.0, key.1);
if let Ok(snapshot) = self.get_session_state(&key).await {
if let Some(handle) = self.sessions.get(&key) {
let created_at_elapsed = handle.created_at.elapsed();
let created_at = std::time::SystemTime::now()
.checked_sub(created_at_elapsed)
.unwrap_or_else(std::time::SystemTime::now);
let last_activity_instant = handle.last_activity();
let last_activity_elapsed = last_activity_instant.elapsed();
let last_activity = std::time::SystemTime::now()
.checked_sub(last_activity_elapsed)
.unwrap_or_else(std::time::SystemTime::now);
current_batch.push((composite_id.clone(), snapshot, created_at, last_activity));
sessions_queued += 1;
if current_batch.len() >= PERSISTENCE_BATCH_SIZE {
let (tx, rx) = tokio::sync::oneshot::channel();
if let Err(e) = self.persistence_sender.send(
PersistenceCommand::PersistBatch {
sessions: std::mem::take(&mut current_batch),
completion: Some(tx),
}
).await {
log::error!("Failed to send persistence batch: {e}");
break;
}
batch_receivers.push(rx);
}
}
}
}
if !current_batch.is_empty() {
let (tx, rx) = tokio::sync::oneshot::channel();
if let Err(e) = self.persistence_sender.send(
PersistenceCommand::PersistBatch {
sessions: current_batch,
completion: Some(tx),
}
).await {
log::error!("Failed to send final persistence batch: {e}");
} else {
batch_receivers.push(rx);
}
}
let batch_count = batch_receivers.len();
log::info!(
"Queued {} sessions in {} batches, waiting for completion",
sessions_queued, batch_count
);
const BATCH_TIMEOUT: Duration = Duration::from_secs(30);
let mut total_persisted = 0usize;
let mut batches_succeeded = 0usize;
let mut batches_failed = 0usize;
let mut batches_timeout = 0usize;
for (batch_num, rx) in batch_receivers.into_iter().enumerate() {
match tokio::time::timeout(BATCH_TIMEOUT, rx).await {
Ok(Ok(Ok(count))) => {
log::debug!("Batch {} completed: {} sessions persisted", batch_num, count);
total_persisted += count;
batches_succeeded += 1;
}
Ok(Ok(Err(e))) => {
log::error!("Batch {} reported failure: {}", batch_num, e);
batches_failed += 1;
}
Ok(Err(_)) => {
log::error!("Batch {} completion channel dropped (task crashed?)", batch_num);
batches_failed += 1;
}
Err(_) => {
log::error!(
"Batch {} timeout after {:?} (persistence may still complete)",
batch_num, BATCH_TIMEOUT
);
batches_timeout += 1;
}
}
}
log::info!(
"Sequential thinking shutdown: {}/{} sessions persisted, \
batches: {} succeeded, {} failed, {} timeout",
total_persisted, sessions_queued,
batches_succeeded, batches_failed, batches_timeout
);
Ok(())
}
pub fn validate_thought(args: SequentialThinkingArgs) -> ThoughtData {
let total_thoughts = args.total_thoughts.max(args.thought_number);
ThoughtData {
thought: args.thought,
thought_number: args.thought_number,
total_thoughts,
next_thought_needed: args.next_thought_needed,
is_revision: args.is_revision,
revises_thought: args.revises_thought,
branch_from_thought: args.branch_from_thought,
branch_id: args.branch_id,
needs_more_thoughts: args.needs_more_thoughts,
}
}
}