use crate::config::SessionConfig;
use crate::error::{AgentError, Result as AgentResult};
use crate::hooks::HookRegistry;
use crate::permissions::PermissionEvaluator;
use crate::routing::MessageRouter;
use crate::session::state::SessionState;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::time::Duration;
use tokio::sync::Mutex;
use turboclaude_protocol::Message;
use turboclaude_transport::{CliTransport, ProcessConfig};
/// Handle for a single agent session.
///
/// Groups the CLI transport, the session configuration, the hook registry,
/// the permission evaluator, the message router and the mutable session
/// state. Every field is wrapped in an `Arc`.
pub struct AgentSession {
/// Transport obtained from `CliTransport::spawn` when the session is created.
pub(crate) transport: Arc<CliTransport>,
/// Configuration the session was created with.
pub(crate) config: Arc<SessionConfig>,
/// Hook registry handed to the message router on construction.
pub(crate) hooks: Arc<HookRegistry>,
/// Permission evaluator, seeded from `config.permission_mode`.
pub(crate) permissions: Arc<PermissionEvaluator>,
/// Current message router. It is an `Option` because `close` takes it out
/// to shut it down and `reconnect` swaps in a replacement.
pub(crate) router: Arc<Mutex<Option<MessageRouter>>>,
/// Mutable session state (history, current model, permission mode,
/// connection flag), guarded by an async mutex.
pub(crate) state: Arc<Mutex<SessionState>>,
/// Starts at 0. Presumably counts in-flight queries; it is not updated
/// anywhere in this file — confirm against the query code.
pub(crate) active_queries: Arc<AtomicU32>,
/// Skill manager; only present when the `skills` feature is enabled.
#[cfg(feature = "skills")]
pub(crate) skill_manager: Arc<tokio::sync::RwLock<Option<crate::skills::SkillManager>>>,
}
impl AgentSession {
/// Creates a session by spawning the CLI process and wiring up its collaborators.
///
/// The CLI is launched from `config.cli_path`; every other process setting
/// comes from `ProcessConfig::default()`. A fresh hook registry, a permission
/// evaluator seeded with `config.permission_mode`, a message router and the
/// initial session state are then built around the spawned transport. With
/// the `skills` feature enabled, a skill registry is built from
/// `config.skill_dirs` and wrapped in a skill manager.
///
/// # Errors
///
/// * `AgentError::Transport` if the CLI process cannot be spawned.
/// * `AgentError::Config` if the skill registry cannot be built (`skills`
///   feature only).
/// * Any error reported by `MessageRouter::new` or `SkillManager::new`.
pub async fn new(config: SessionConfig) -> AgentResult<Self> {
    let spawn_settings = ProcessConfig {
        cli_path: config.cli_path.clone(),
        ..Default::default()
    };
    let transport = match CliTransport::spawn(spawn_settings).await {
        Ok(spawned) => Arc::new(spawned),
        Err(e) => {
            return Err(AgentError::Transport(format!(
                "Failed to spawn CLI: {}",
                e
            )));
        }
    };

    let hooks = Arc::new(HookRegistry::new());
    let permissions = Arc::new(PermissionEvaluator::new(config.permission_mode));

    // The router shares the transport, hooks and permission evaluator with
    // the session itself.
    let router = MessageRouter::new(
        Arc::clone(&transport),
        Arc::clone(&hooks),
        Arc::clone(&permissions),
    )
    .await?;

    let initial_state = SessionState::new(config.default_model.clone(), config.permission_mode);

    #[cfg(feature = "skills")]
    let skill_manager = {
        let mut builder = turboclaude_skills::SkillRegistry::builder();
        for dir in &config.skill_dirs {
            builder = builder.skill_dir(dir.clone());
        }
        let registry = match builder.build() {
            Ok(built) => built,
            Err(e) => {
                return Err(AgentError::Config(format!(
                    "Failed to create skill registry: {}",
                    e
                )));
            }
        };
        let manager = crate::skills::SkillManager::new(registry).await?;
        Arc::new(tokio::sync::RwLock::new(Some(manager)))
    };

    let session = Self {
        transport,
        config: Arc::new(config),
        hooks,
        permissions,
        router: Arc::new(Mutex::new(Some(router))),
        state: Arc::new(Mutex::new(initial_state)),
        active_queries: Arc::new(AtomicU32::new(0)),
        #[cfg(feature = "skills")]
        skill_manager,
    };
    Ok(session)
}
/// Creates an independent session that starts from this session's
/// conversation history, current model and current permission mode.
///
/// The history, model and permission mode are captured together under a
/// single lock of this session's state *before* the new session is spawned,
/// so the fork always reflects one consistent point in time even if this
/// session keeps changing while the new CLI process starts up. The new
/// session is built from a clone of this session's configuration.
///
/// # Errors
///
/// Propagates any error from `AgentSession::new`.
pub async fn fork(&self) -> AgentResult<AgentSession> {
    // One lock acquisition, one coherent snapshot.
    let (history, model, permission_mode) = {
        let state = self.state.lock().await;
        (
            state.get_history(),
            state.current_model.clone(),
            state.current_permission_mode,
        )
    };

    let forked = AgentSession::new((*self.config).clone()).await?;

    {
        let mut forked_state = forked.state.lock().await;
        for msg in history {
            forked_state.add_to_history(msg);
        }
        forked_state.current_model = model;
        forked_state.current_permission_mode = permission_mode;
    }

    Ok(forked)
}
/// Returns a clone of the current session state, taken while holding the
/// state lock.
pub async fn state(&self) -> SessionState {
    let guard = self.state.lock().await;
    guard.clone()
}
/// Reports the `is_connected` flag stored in the session state.
///
/// This only reads the recorded flag; it does not probe the transport.
pub async fn is_connected(&self) -> bool {
    let guard = self.state.lock().await;
    guard.is_connected
}
/// Shuts the session down.
///
/// Marks the state as disconnected, takes the router out of its slot and
/// shuts it down (the shutdown result is deliberately ignored), then kills
/// the transport.
///
/// # Errors
///
/// Returns `AgentError::Transport` if killing the transport fails.
pub async fn close(&self) -> AgentResult<()> {
    // The temporary guard is released at the end of this statement.
    self.state.lock().await.is_connected = false;

    {
        let mut slot = self.router.lock().await;
        if let Some(mut active_router) = slot.take() {
            // Best effort: a failed router shutdown must not prevent the
            // transport from being killed below.
            let _ = active_router.shutdown().await;
        }
    }

    match self.transport.kill().await {
        Ok(_) => Ok(()),
        Err(e) => Err(AgentError::Transport(format!(
            "Failed to kill transport: {}",
            e
        ))),
    }
}
/// Makes sure the transport is alive, reconnecting with exponential backoff
/// if it is not.
///
/// Returns immediately when `transport.is_alive()` reports true. Otherwise
/// `reconnect` is attempted up to `MAX_ATTEMPTS` times. After each failed
/// attempt except the last, the task sleeps for the current backoff, which
/// starts at `INITIAL_BACKOFF` and doubles each time, capped at
/// `MAX_BACKOFF`. On success the state's `is_connected` flag is set to true.
///
/// # Errors
///
/// Returns the error from the final `reconnect` attempt when every attempt
/// fails.
pub(crate) async fn ensure_connected(&self) -> AgentResult<()> {
    const MAX_ATTEMPTS: u32 = 5;
    const INITIAL_BACKOFF: Duration = Duration::from_millis(500);
    const MAX_BACKOFF: Duration = Duration::from_secs(60);

    if self.transport.is_alive().await {
        return Ok(());
    }

    let mut backoff = INITIAL_BACKOFF;
    let mut attempt: u32 = 1;
    loop {
        match self.reconnect().await {
            Ok(()) => {
                self.state.lock().await.is_connected = true;
                return Ok(());
            }
            // Out of attempts: surface the real cause rather than a generic
            // message.
            Err(e) if attempt >= MAX_ATTEMPTS => return Err(e),
            Err(_) => {
                tokio::time::sleep(backoff).await;
                // Double using `Duration` arithmetic; no lossy integer casts.
                backoff = std::cmp::min(backoff.saturating_mul(2), MAX_BACKOFF);
                attempt += 1;
            }
        }
    }
}
/// Tears down the current CLI transport and installs a freshly built
/// message router.
///
/// Steps, in order: kill `self.transport` (result ignored), spawn a new
/// `CliTransport` from `config.cli_path`, build a new `MessageRouter`, then
/// shut down the router currently stored in `self.router` (result ignored)
/// and replace it with the new one.
///
/// # Errors
///
/// Returns `AgentError::Transport` if spawning the CLI fails, and propagates
/// any error from `MessageRouter::new`.
///
/// NOTE(review): the freshly spawned transport is bound to `_new_transport`
/// and never used again, while the new router is built from
/// `self.transport` — the handle that was killed at the top of this
/// function. `self.transport` is never reassigned here. It looks like the
/// reconnect never attaches to the new process; confirm, and if so the
/// `transport` field probably needs to become swappable.
pub(crate) async fn reconnect(&self) -> AgentResult<()> {
// Best effort: the old process may already be gone.
let _ = self.transport.kill().await;
let process_config = ProcessConfig {
cli_path: self.config.cli_path.clone(),
..Default::default()
};
// NOTE(review): spawned but unused — see the note in the doc comment.
let _new_transport = CliTransport::spawn(process_config)
.await
.map_err(|e| AgentError::Transport(format!("Failed to spawn new CLI: {}", e)))?;
// NOTE(review): this passes the old `self.transport`, not `_new_transport`.
let new_router = MessageRouter::new(
Arc::clone(&self.transport),
Arc::clone(&self.hooks),
Arc::clone(&self.permissions),
)
.await?;
{
// Swap routers while holding the router lock; the old router's shutdown
// result is ignored.
let mut router_lock = self.router.lock().await;
if let Some(mut old_router) = router_lock.take() {
let _ = old_router.shutdown().await;
}
*router_lock = Some(new_router);
}
Ok(())
}
/// Appends `message` to the conversation history held in the session state.
// Not referenced within this file, hence the lint allowance.
#[allow(dead_code)]
pub(crate) async fn add_message_to_history(&self, message: Message) {
    self.state.lock().await.add_to_history(message);
}
/// Returns the conversation history as reported by the session state's
/// `get_history`, read while holding the state lock.
// Not referenced within this file, hence the lint allowance.
#[allow(dead_code)]
pub(crate) async fn get_conversation_history(&self) -> Vec<Message> {
    self.state.lock().await.get_history()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use turboclaude_protocol::PermissionMode;

    /// A freshly constructed `SessionState` is connected, carries the model
    /// and permission mode it was given, and has no active queries.
    #[test]
    fn test_session_state_new() {
        let model = "claude-3-5-sonnet";
        let fresh = SessionState::new(model.to_string(), PermissionMode::Default);

        assert!(fresh.is_connected);
        assert_eq!(fresh.current_model, model);
        assert_eq!(fresh.current_permission_mode, PermissionMode::Default);
        assert_eq!(fresh.active_queries, 0);
    }

    /// Only checks that a default `SessionConfig` can be constructed; no
    /// session is created and no CLI process is spawned.
    #[tokio::test]
    async fn test_session_creation() {
        let _config = SessionConfig::default();
    }

    /// Doubling the backoff five times from 500 ms never exceeds the
    /// 60-second cap.
    #[test]
    fn test_backoff_calculation() {
        let cap = Duration::from_secs(60);
        let backoff = (0..5).fold(Duration::from_millis(500), |current, _| {
            let doubled = current.as_millis() as u64 * 2;
            Duration::from_millis(std::cmp::min(doubled, cap.as_millis() as u64))
        });

        assert!(backoff <= cap);
    }
}