use crate::product::agent::AuthManager;
#[cfg(any(test, feature = "test-support"))]
use crate::product::agent::CodexAuth;
use crate::product::agent::codex::Codex;
use crate::product::agent::codex::CodexSpawnOk;
use crate::product::agent::codex::INITIAL_SUBMIT_ID;
use crate::product::agent::codex_thread::CodexThread;
use crate::product::agent::config::Config;
use crate::product::agent::error::CodexErr;
use crate::product::agent::error::Result as CodexResult;
use crate::product::agent::models_manager::manager::ModelsManager;
use crate::product::agent::protocol::Event;
use crate::product::agent::protocol::EventMsg;
use crate::product::agent::protocol::SessionConfiguredEvent;
use crate::product::agent::rollout::RolloutRecorder;
use crate::product::agent::rollout::truncation;
use crate::product::agent::skills::SkillsManager;
use crate::product::protocol::ThreadId;
use crate::product::protocol::config_types::IdentityMask;
use crate::product::protocol::openai_models::ModelInfo;
use crate::product::protocol::openai_models::ModelPreset;
use crate::product::protocol::protocol::InitialHistory;
use crate::product::protocol::protocol::McpServerRefreshConfig;
use crate::product::protocol::protocol::Op;
use crate::product::protocol::protocol::RolloutItem;
use crate::product::protocol::protocol::SessionSource;
use lha_llm::CatalogRefreshStrategy;
use lha_llm::RuntimeEndpoint;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
#[cfg(any(test, feature = "test-support"))]
use tempfile::TempDir;
use tokio::sync::RwLock;
use tokio::sync::TryLockError;
use tokio::sync::broadcast;
use tracing::warn;
/// Capacity of the broadcast channel that announces newly registered thread ids.
///
/// With `tokio::sync::broadcast`, a receiver that falls more than this many
/// messages behind observes a `Lagged` error and misses the oldest ids.
const THREAD_CREATED_CHANNEL_CAPACITY: usize = 1_024;
/// Everything a caller receives after a thread has been spawned and registered.
pub struct NewThread {
    /// Id under which the thread was inserted into the manager's registry.
    pub thread_id: ThreadId,
    /// Shared handle to the spawned thread (the same `Arc` held by the registry).
    pub thread: Arc<CodexThread>,
    /// The `SessionConfigured` event the session emitted as its first event.
    pub session_configured: SessionConfiguredEvent,
}
/// Front door for creating, looking up, forking and shutting down threads.
///
/// All real state lives in a shared [`ThreadManagerState`]; cloning a
/// `ThreadManager` only bumps that `Arc`'s reference count.
pub struct ThreadManager {
    /// State shared by this manager and all of its clones.
    state: Arc<ThreadManagerState>,
    /// Test-only owner of a temporary home directory; `tempfile::TempDir`
    /// deletes the directory when dropped.
    #[cfg(any(test, feature = "test-support"))]
    _test_lha_home_guard: Option<TempDir>,
}
impl Clone for ThreadManager {
    /// Produces a handle onto the same shared state.
    ///
    /// The test-only temp-dir guard is not carried over: the clone gets `None`.
    // NOTE(review): because only the original owns the `TempDir`, dropping the
    // original first removes the temporary home directory while clones may still
    // be using it. Consider sharing the guard (e.g. `Option<Arc<TempDir>>`).
    fn clone(&self) -> Self {
        let shared_state = Arc::clone(&self.state);
        Self {
            state: shared_state,
            #[cfg(any(test, feature = "test-support"))]
            _test_lha_home_guard: None,
        }
    }
}
/// State shared by every clone of a [`ThreadManager`].
pub(crate) struct ThreadManagerState {
    /// Registered threads keyed by their id.
    threads: Arc<RwLock<HashMap<ThreadId, Arc<CodexThread>>>>,
    /// Sends the id of each thread right after it is registered.
    thread_created_tx: broadcast::Sender<ThreadId>,
    /// Default auth manager for spawned threads; also handed to the memories
    /// startup task.
    auth_manager: Arc<AuthManager>,
    /// Model catalog shared with every spawned session.
    models_manager: Arc<ModelsManager>,
    /// Skills registry shared with every spawned session.
    skills_manager: Arc<SkillsManager>,
    /// Session source used when a spawn call does not supply one explicitly.
    session_source: SessionSource,
    /// Test-only log of `(thread, op)` pairs, read back via `captured_ops`.
    // NOTE(review): no writer is visible in this file; presumably populated by
    // test code elsewhere, which is why it may be dead in some builds.
    #[cfg(any(test, feature = "test-support"))]
    #[allow(dead_code)]
    ops_log: Arc<std::sync::Mutex<Vec<(ThreadId, Op)>>>,
}
impl ThreadManager {
    /// Creates a manager whose models and skills managers are rooted at `lha_home`.
    ///
    /// `model_provider_id` and `provider` are passed to `ModelsManager::new`;
    /// `session_source` becomes the default source for spawned threads.
    pub fn new(
        lha_home: PathBuf,
        auth_manager: Arc<AuthManager>,
        model_provider_id: &str,
        provider: RuntimeEndpoint,
        session_source: SessionSource,
    ) -> Self {
        let models_manager = Arc::new(ModelsManager::new(
            lha_home.clone(),
            Arc::clone(&auth_manager),
            model_provider_id,
            provider,
        ));
        let skills_manager = Arc::new(SkillsManager::new(lha_home));
        Self::from_parts(auth_manager, models_manager, skills_manager, session_source)
    }

    /// Creates a test manager backed by a freshly created temporary home directory.
    ///
    /// This instance (not its clones) owns the `TempDir`, so the directory is
    /// deleted when this instance is dropped.
    ///
    /// # Panics
    /// Panics if the temporary directory cannot be created.
    #[cfg(any(test, feature = "test-support"))]
    pub fn with_models_provider(auth: CodexAuth, provider: RuntimeEndpoint) -> Self {
        let temp_dir = tempfile::tempdir().unwrap_or_else(|err| panic!("temp codex home: {err}"));
        let lha_home = temp_dir.path().to_path_buf();
        let mut manager =
            Self::with_models_provider_and_home(auth, "test-provider", provider, lha_home);
        manager._test_lha_home_guard = Some(temp_dir);
        manager
    }

    /// Creates a test manager rooted at `lha_home` and authenticated with `auth`.
    ///
    /// Uses `ModelsManager::with_provider` and defaults the session source to
    /// `SessionSource::Exec`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn with_models_provider_and_home(
        auth: CodexAuth,
        model_provider_id: &str,
        provider: RuntimeEndpoint,
        lha_home: PathBuf,
    ) -> Self {
        let auth_manager = AuthManager::from_auth_for_testing(auth);
        let models_manager = Arc::new(ModelsManager::with_provider(
            lha_home.clone(),
            Arc::clone(&auth_manager),
            model_provider_id,
            provider,
        ));
        let skills_manager = Arc::new(SkillsManager::new(lha_home));
        Self::from_parts(
            auth_manager,
            models_manager,
            skills_manager,
            SessionSource::Exec,
        )
    }

    /// Wires already-built collaborators into a manager with an empty thread registry.
    ///
    /// Shared by every constructor so the state layout is defined in one place.
    fn from_parts(
        auth_manager: Arc<AuthManager>,
        models_manager: Arc<ModelsManager>,
        skills_manager: Arc<SkillsManager>,
        session_source: SessionSource,
    ) -> Self {
        // The initial receiver is discarded; listeners obtain their own via
        // `subscribe_thread_created`.
        let (thread_created_tx, _) = broadcast::channel(THREAD_CREATED_CHANNEL_CAPACITY);
        Self {
            state: Arc::new(ThreadManagerState {
                threads: Arc::new(RwLock::new(HashMap::new())),
                thread_created_tx,
                auth_manager,
                models_manager,
                skills_manager,
                session_source,
                #[cfg(any(test, feature = "test-support"))]
                ops_log: Arc::new(std::sync::Mutex::new(Vec::new())),
            }),
            #[cfg(any(test, feature = "test-support"))]
            _test_lha_home_guard: None,
        }
    }

    /// Returns the default session source used for threads spawned by this manager.
    pub fn session_source(&self) -> SessionSource {
        self.state.session_source.clone()
    }

    /// Returns the shared skills manager.
    pub fn skills_manager(&self) -> Arc<SkillsManager> {
        self.state.skills_manager.clone()
    }

    /// Returns the shared models manager.
    pub fn get_models_manager(&self) -> Arc<ModelsManager> {
        self.state.models_manager.clone()
    }

    /// Lists models by delegating to the models manager with `refresh_strategy`.
    pub async fn list_models(
        &self,
        config: &Config,
        refresh_strategy: CatalogRefreshStrategy,
    ) -> Vec<ModelPreset> {
        self.state
            .models_manager
            .list_models(config, refresh_strategy)
            .await
    }

    /// Lists picker models by delegating to the models manager with `refresh_strategy`.
    pub async fn list_picker_models(
        &self,
        config: &Config,
        refresh_strategy: CatalogRefreshStrategy,
    ) -> Vec<ModelPreset> {
        self.state
            .models_manager
            .list_picker_models(config, refresh_strategy)
            .await
    }

    /// Non-waiting variant of [`Self::list_models`].
    ///
    /// # Errors
    /// Returns [`TryLockError`] when the models manager cannot take its lock immediately.
    pub fn try_list_models(&self, config: &Config) -> Result<Vec<ModelPreset>, TryLockError> {
        self.state.models_manager.try_list_models(config)
    }

    /// Non-waiting variant of [`Self::list_picker_models`].
    ///
    /// # Errors
    /// Returns [`TryLockError`] when the models manager cannot take its lock immediately.
    pub fn try_list_picker_models(
        &self,
        config: &Config,
    ) -> Result<Vec<ModelPreset>, TryLockError> {
        self.state.models_manager.try_list_picker_models(config)
    }

    /// Lists model-switcher models by delegating to the models manager.
    pub async fn list_model_switcher_models(
        &self,
        config: &Config,
        refresh_strategy: CatalogRefreshStrategy,
    ) -> Vec<ModelPreset> {
        self.state
            .models_manager
            .list_model_switcher_models(config, refresh_strategy)
            .await
    }

    /// Non-waiting variant of [`Self::list_model_switcher_models`].
    ///
    /// # Errors
    /// Returns [`TryLockError`] when the models manager cannot take its lock immediately.
    pub fn try_list_model_switcher_models(
        &self,
        config: &Config,
    ) -> Result<Vec<ModelPreset>, TryLockError> {
        self.state
            .models_manager
            .try_list_model_switcher_models(config)
    }

    /// Lists identity masks known to the models manager.
    pub fn list_identities(&self) -> Vec<IdentityMask> {
        self.state.models_manager.list_identities()
    }

    /// Delegates to the models manager's `try_is_official_openai_model`.
    ///
    /// # Errors
    /// Returns [`TryLockError`] when the models manager cannot take its lock immediately.
    pub fn try_is_official_openai_model(
        &self,
        config: &Config,
        model: &str,
        model_provider_id: &str,
    ) -> Result<bool, TryLockError> {
        self.state
            .models_manager
            .try_is_official_openai_model(config, model, model_provider_id)
    }

    /// Resolves the default model by delegating to the models manager.
    pub async fn get_default_model(
        &self,
        model: &Option<String>,
        config: &Config,
        refresh_strategy: CatalogRefreshStrategy,
    ) -> CodexResult<String> {
        self.state
            .models_manager
            .get_default_model(model, config, refresh_strategy)
            .await
    }

    /// Returns model metadata for `model` from the models manager.
    pub async fn get_model_info(&self, model: &str, config: &Config) -> ModelInfo {
        self.state
            .models_manager
            .get_model_info(model, config)
            .await
    }

    /// Points the shared models manager at a different provider.
    pub async fn switch_model_provider(&self, model_provider_id: &str, provider: RuntimeEndpoint) {
        self.state
            .models_manager
            .switch_provider(model_provider_id, provider)
            .await;
    }

    /// Returns the ids of all registered threads, in no particular order.
    pub async fn list_thread_ids(&self) -> Vec<ThreadId> {
        self.state.threads.read().await.keys().copied().collect()
    }

    /// Asks every registered thread to refresh its MCP servers.
    ///
    /// The registry is snapshotted first so no lock is held while submitting.
    /// Per-thread failures are logged and do not stop the remaining requests.
    pub async fn refresh_mcp_servers(&self, refresh_config: McpServerRefreshConfig) {
        let threads = self
            .state
            .threads
            .read()
            .await
            .values()
            .cloned()
            .collect::<Vec<_>>();
        for thread in threads {
            if let Err(err) = thread
                .submit(Op::RefreshMcpServers {
                    config: refresh_config.clone(),
                })
                .await
            {
                warn!("failed to request MCP server refresh: {err}");
            }
        }
    }

    /// Subscribes to ids of threads registered after this call.
    pub fn subscribe_thread_created(&self) -> broadcast::Receiver<ThreadId> {
        self.state.thread_created_tx.subscribe()
    }

    /// Looks up a registered thread.
    ///
    /// # Errors
    /// Returns `CodexErr::ThreadNotFound` when `thread_id` is not registered.
    pub async fn get_thread(&self, thread_id: ThreadId) -> CodexResult<Arc<CodexThread>> {
        self.state.get_thread(thread_id).await
    }

    /// Starts a thread with empty history and no dynamic tools.
    pub async fn start_thread(&self, config: Config) -> CodexResult<NewThread> {
        self.start_thread_with_tools(config, Vec::new()).await
    }

    /// Starts a thread with empty history, the manager's auth manager and the
    /// given dynamic tools.
    pub async fn start_thread_with_tools(
        &self,
        config: Config,
        dynamic_tools: Vec<crate::product::protocol::dynamic_tools::DynamicToolSpec>,
    ) -> CodexResult<NewThread> {
        self.state
            .spawn_thread(
                config,
                InitialHistory::New,
                Arc::clone(&self.state.auth_manager),
                dynamic_tools,
            )
            .await
    }

    /// Loads history from the rollout file at `rollout_path` and resumes it.
    ///
    /// # Errors
    /// Propagates failures from reading the rollout or spawning the thread.
    pub async fn resume_thread_from_rollout(
        &self,
        config: Config,
        rollout_path: PathBuf,
        auth_manager: Arc<AuthManager>,
    ) -> CodexResult<NewThread> {
        let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
        self.resume_thread_with_history(config, initial_history, auth_manager)
            .await
    }

    /// Spawns a thread seeded with `initial_history`, using the caller's auth manager.
    pub async fn resume_thread_with_history(
        &self,
        config: Config,
        initial_history: InitialHistory,
        auth_manager: Arc<AuthManager>,
    ) -> CodexResult<NewThread> {
        self.state
            .spawn_thread(config, initial_history, auth_manager, Vec::new())
            .await
    }

    /// Unregisters a thread and returns it, if present.
    ///
    /// This does not submit `Op::Shutdown`; the caller owns the returned handle.
    pub async fn remove_thread(&self, thread_id: &ThreadId) -> Option<Arc<CodexThread>> {
        self.state.remove_thread(thread_id).await
    }

    /// Removes every registered thread and submits `Op::Shutdown` to each.
    ///
    /// The registry is drained under a single write lock before any shutdown
    /// is submitted. This means no lock is held across the `submit` awaits, and
    /// every drained thread is asked to shut down.
    ///
    /// # Errors
    /// Every drained thread is attempted even if some submissions fail. The
    /// first failure is returned, later failures are logged, and the registry
    /// is empty afterwards either way.
    pub async fn remove_and_close_all_threads(&self) -> CodexResult<()> {
        let threads: Vec<Arc<CodexThread>> = {
            let mut registry = self.state.threads.write().await;
            registry.drain().map(|(_, thread)| thread).collect()
        };
        let mut first_error = None;
        for thread in threads {
            if let Err(err) = thread.submit(Op::Shutdown).await {
                if first_error.is_none() {
                    first_error = Some(err);
                } else {
                    warn!("failed to shut down thread: {err}");
                }
            }
        }
        match first_error {
            Some(err) => Err(err),
            None => Ok(()),
        }
    }

    /// Forks the rollout at `path`, keeping only the history before the
    /// `nth_user_message`-th user message (as decided by `truncation`).
    ///
    /// If nothing is kept, the fork starts with `InitialHistory::New`.
    pub async fn fork_thread(
        &self,
        nth_user_message: usize,
        config: Config,
        path: PathBuf,
    ) -> CodexResult<NewThread> {
        let history = RolloutRecorder::get_rollout_history(&path).await?;
        let history = truncate_before_nth_user_message(history, nth_user_message);
        self.state
            .spawn_thread(
                config,
                history,
                Arc::clone(&self.state.auth_manager),
                Vec::new(),
            )
            .await
    }

    /// Like [`Self::fork_thread`], but spawns with an explicit `session_source`.
    pub async fn fork_thread_with_source(
        &self,
        nth_user_message: usize,
        config: Config,
        path: PathBuf,
        session_source: SessionSource,
    ) -> CodexResult<NewThread> {
        let history = RolloutRecorder::get_rollout_history(&path).await?;
        let history = truncate_before_nth_user_message(history, nth_user_message);
        self.state
            .spawn_thread_with_source(
                config,
                history,
                Arc::clone(&self.state.auth_manager),
                session_source,
                Vec::new(),
            )
            .await
    }

    /// Returns a snapshot of the test ops log.
    ///
    /// A poisoned mutex only means another holder panicked. The log is still
    /// readable, so the data is recovered instead of being reported as empty.
    #[cfg(any(test, feature = "test-support"))]
    #[allow(dead_code)] // presumably unused by some test-support consumers
    pub(crate) fn captured_ops(&self) -> Vec<(ThreadId, Op)> {
        self.state
            .ops_log
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }
}
impl ThreadManagerState {
    /// Fetches a registered thread by id.
    ///
    /// # Errors
    /// Returns `CodexErr::ThreadNotFound` when `thread_id` is not registered.
    pub(crate) async fn get_thread(&self, thread_id: ThreadId) -> CodexResult<Arc<CodexThread>> {
        match self.threads.read().await.get(&thread_id) {
            Some(found) => Ok(Arc::clone(found)),
            None => Err(CodexErr::ThreadNotFound(thread_id)),
        }
    }

    /// Removes a thread from the registry, returning it if it was present.
    pub(crate) async fn remove_thread(&self, thread_id: &ThreadId) -> Option<Arc<CodexThread>> {
        let mut registry = self.threads.write().await;
        registry.remove(thread_id)
    }

    /// Spawns a thread using this state's default session source.
    pub(crate) async fn spawn_thread(
        &self,
        config: Config,
        initial_history: InitialHistory,
        auth_manager: Arc<AuthManager>,
        dynamic_tools: Vec<crate::product::protocol::dynamic_tools::DynamicToolSpec>,
    ) -> CodexResult<NewThread> {
        let default_source = self.session_source.clone();
        self.spawn_thread_with_source(
            config,
            initial_history,
            auth_manager,
            default_source,
            dynamic_tools,
        )
        .await
    }

    /// Spawns a `Codex` session and registers it as a thread.
    ///
    /// The first event must be the `SessionConfigured` event (checked in
    /// `finalize_thread_spawn`). Only after registration succeeds is the
    /// memories startup task launched.
    pub(crate) async fn spawn_thread_with_source(
        &self,
        config: Config,
        initial_history: InitialHistory,
        auth_manager: Arc<AuthManager>,
        session_source: SessionSource,
        dynamic_tools: Vec<crate::product::protocol::dynamic_tools::DynamicToolSpec>,
    ) -> CodexResult<NewThread> {
        // `Codex::spawn` consumes these, but the memories task needs its own copies.
        let memories_config = config.clone();
        let memories_source = session_source.clone();
        let spawned = Codex::spawn(
            config,
            auth_manager,
            Arc::clone(&self.models_manager),
            Arc::clone(&self.skills_manager),
            initial_history,
            session_source,
            dynamic_tools,
        )
        .await?;
        let CodexSpawnOk {
            codex, thread_id, ..
        } = spawned;
        let registered = self.finalize_thread_spawn(codex, thread_id).await?;
        crate::product::agent::memories::startup::start_memories_startup_task(
            Arc::clone(&self.auth_manager),
            Arc::clone(&self.models_manager),
            Arc::clone(&self.skills_manager),
            memories_config,
            registered.thread_id,
            Arc::clone(&registered.thread),
            memories_source,
        );
        Ok(registered)
    }

    /// Waits for the session's first event, then registers and announces the thread.
    ///
    /// # Errors
    /// Returns `CodexErr::SessionConfiguredNotFirstEvent` unless the first event is
    /// a `SessionConfigured` message whose id equals `INITIAL_SUBMIT_ID`. Errors from
    /// `next_event` are propagated.
    async fn finalize_thread_spawn(
        &self,
        codex: Codex,
        thread_id: ThreadId,
    ) -> CodexResult<NewThread> {
        let first_event = codex.next_event().await?;
        let session_configured = match first_event {
            Event {
                id,
                msg: EventMsg::SessionConfigured(configured),
            } if id == INITIAL_SUBMIT_ID => configured,
            _ => return Err(CodexErr::SessionConfiguredNotFirstEvent),
        };
        let rollout_path = session_configured.rollout_path.clone();
        let thread = Arc::new(CodexThread::new(codex, rollout_path));
        {
            let mut registry = self.threads.write().await;
            registry.insert(thread_id, Arc::clone(&thread));
        }
        // A send error only means nobody is subscribed right now; that is fine.
        let _ = self.thread_created_tx.send(thread_id);
        Ok(NewThread {
            thread_id,
            thread,
            session_configured,
        })
    }
}
/// Keeps only the rollout items that precede the `n`-th user message, as
/// decided by `truncation::truncate_rollout_before_nth_user_message_from_start`.
///
/// Returns `InitialHistory::New` when nothing survives, otherwise the kept
/// items as `InitialHistory::Forked`.
fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> InitialHistory {
    let rollout_items: Vec<RolloutItem> = history.get_rollout_items();
    let kept = truncation::truncate_rollout_before_nth_user_message_from_start(&rollout_items, n);
    if kept.is_empty() {
        return InitialHistory::New;
    }
    InitialHistory::Forked(kept)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::product::agent::codex::make_session_and_context;
    use crate::product::protocol::models::ContentItem;
    use crate::product::protocol::models::ReasoningItemReasoningSummary;
    use crate::product::protocol::models::TranscriptItem;
    use assert_matches::assert_matches;
    use pretty_assertions::assert_eq;

    /// Builds a single-part text message attributed to `role`.
    fn text_msg(role: &str, text: &str) -> TranscriptItem {
        TranscriptItem::Message {
            id: None,
            role: role.to_string(),
            content: vec![ContentItem::OutputText {
                text: text.to_string(),
            }],
            end_turn: None,
        }
    }

    fn user_msg(text: &str) -> TranscriptItem {
        text_msg("user", text)
    }

    fn assistant_msg(text: &str) -> TranscriptItem {
        text_msg("assistant", text)
    }

    /// Wraps each transcript item as a rollout item, preserving order.
    fn to_rollout(items: &[TranscriptItem]) -> Vec<RolloutItem> {
        items
            .iter()
            .cloned()
            .map(RolloutItem::TranscriptItem)
            .collect()
    }

    /// Asserts that two rollout lists have identical JSON representations.
    fn assert_same_rollout(got: &[RolloutItem], want: &[RolloutItem]) {
        assert_eq!(
            serde_json::to_value(got).unwrap(),
            serde_json::to_value(want).unwrap()
        );
    }

    #[test]
    fn drops_from_last_user_only() {
        let items = [
            user_msg("u1"),
            assistant_msg("a1"),
            assistant_msg("a2"),
            user_msg("u2"),
            assistant_msg("a3"),
            TranscriptItem::Reasoning {
                id: "r1".to_string(),
                summary: vec![ReasoningItemReasoningSummary::SummaryText {
                    text: "s".to_string(),
                }],
                content: None,
                encrypted_content: None,
            },
            TranscriptItem::ToolCall {
                id: None,
                call_id: "c1".to_string(),
                tool_name: "tool".to_string(),
                payload: lha_llm::ToolCallPayload::JsonArguments {
                    arguments: "{}".to_string(),
                },
            },
            assistant_msg("a4"),
        ];

        let truncated =
            truncate_before_nth_user_message(InitialHistory::Forked(to_rollout(&items)), 1);
        assert_same_rollout(&truncated.get_rollout_items(), &to_rollout(&items[..3]));

        let truncated_second =
            truncate_before_nth_user_message(InitialHistory::Forked(to_rollout(&items)), 2);
        assert_matches!(truncated_second, InitialHistory::New);
    }

    #[tokio::test]
    async fn ignores_session_prefix_messages_when_truncating() {
        let (session, turn_context) = make_session_and_context().await;
        let mut items = session.build_initial_context(&turn_context).await;
        items.push(user_msg("feature request"));
        items.push(assistant_msg("ack"));
        items.push(user_msg("second question"));
        items.push(assistant_msg("answer"));

        let truncated =
            truncate_before_nth_user_message(InitialHistory::Forked(to_rollout(&items)), 1);
        assert_same_rollout(&truncated.get_rollout_items(), &to_rollout(&items[..4]));
    }
}