use std::sync::Arc;
use tempfile::TempDir;
use super::super::mock::{MockChatProvider, MockResolver};
use crate::core::sm::agent::{SessionManagerAgent, SmAgentError};
use crate::core::sm::config::SessionManagerConfig;
/// Returns the default [`SessionManagerConfig`] with `enabled` switched on;
/// every other field keeps its `Default` value.
fn enabled_config() -> SessionManagerConfig {
    let base = SessionManagerConfig::default();
    SessionManagerConfig {
        enabled: true,
        ..base
    }
}
/// Builds a [`SessionManagerAgent`] wired to `resolver` and rooted at `data_root`.
///
/// Hides the `sm-memory` feature split in `with_runtime`'s signature: when the
/// feature is enabled the extra memory argument is passed as `None`, so agents
/// built here never have a memory store attached.
fn agent_with(
    cfg: SessionManagerConfig,
    resolver: Arc<MockResolver>,
    data_root: &std::path::Path,
) -> SessionManagerAgent {
    let root = data_root.to_path_buf();
    #[cfg(feature = "sm-memory")]
    {
        SessionManagerAgent::with_runtime(cfg, resolver, root, None)
    }
    #[cfg(not(feature = "sm-memory"))]
    {
        SessionManagerAgent::with_runtime(cfg, resolver, root)
    }
}
/// One turn against a mock provider returns the provider's reply and per-call
/// cost and echoes the caller's `conv_id`. It also sends a request whose system
/// message contains both the SM prompt and the framework floor, whose last
/// message is the user's input, and whose temperature is 0.3.
#[tokio::test]
async fn chat_drives_full_turn_with_mock_provider() {
    let tmp = TempDir::new().unwrap();
    let provider = MockChatProvider::new("here is my plan", 0.0042);
    let resolver = Arc::new(MockResolver::with_provider(provider.clone()));
    let agent = agent_with(enabled_config(), resolver, tmp.path());
    let outcome = agent
        .chat("decompose the login feature", Some("conv-1"))
        .await
        .expect("chat turn succeeds with a provider");
    assert_eq!(outcome.reply, "here is my plan");
    assert_eq!(outcome.conv_id, "conv-1");
    // Float comparison by tolerance; print the actual value so a failure is actionable.
    assert!(
        (outcome.cost_usd - 0.0042).abs() < 1e-9,
        "per-call cost returned, got {}",
        outcome.cost_usd
    );
    let req = provider.last_request().expect("provider was called");
    assert!(
        req.system.contains("# Session Manager (SM) -- trusty-mpm"),
        "system message must include the SM system prompt"
    );
    assert!(
        req.system.contains("# BASE_SM Framework Floor"),
        "system message must include the non-overridable floor"
    );
    let last = req.messages.last().expect("at least the current message");
    assert_eq!(last.role, "user");
    assert_eq!(last.content, "decompose the login feature");
    assert!(
        (req.temperature - 0.3).abs() < 1e-6,
        "SM turns must use temperature 0.3, got {}",
        req.temperature
    );
}
/// Two turns on the same conversation id: the second request must still
/// contain the first round's user message, and its last message must be the
/// new input.
#[tokio::test]
async fn chat_carries_prior_round_into_next_turn() {
    let tmp = TempDir::new().unwrap();
    let provider = MockChatProvider::new("ack", 0.0);
    let agent = agent_with(
        enabled_config(),
        Arc::new(MockResolver::with_provider(provider.clone())),
        tmp.path(),
    );

    for text in ["first message", "second message"] {
        agent.chat(text, Some("c")).await.unwrap();
    }

    let req = provider.last_request().unwrap();
    let contents = req
        .messages
        .iter()
        .map(|m| m.content.as_str())
        .collect::<Vec<&str>>();
    let carried = contents.iter().any(|c| c.contains("first message"));
    assert!(
        carried,
        "the prior round must be carried into the next turn's prompt, got {contents:?}"
    );
    assert_eq!(req.messages.last().unwrap().content, "second message");
}
/// A round recorded by one agent is resumed by a separately constructed agent
/// opened on the same data root: the second agent's request must include the
/// first agent's user message.
#[tokio::test]
async fn chat_records_round_to_persistent_store() {
    let tmp = TempDir::new().unwrap();
    let root = tmp.path();

    // First agent: records a round under conversation "persist".
    let writer = agent_with(
        enabled_config(),
        Arc::new(MockResolver::with_provider(MockChatProvider::new(
            "a-reply", 0.0,
        ))),
        root,
    );
    writer
        .chat("remember this", Some("persist"))
        .await
        .unwrap();

    // Second agent: a fresh instance over the same directory.
    let reader_provider = MockChatProvider::new("b-reply", 0.0);
    let reader = agent_with(
        enabled_config(),
        Arc::new(MockResolver::with_provider(reader_provider.clone())),
        root,
    );
    reader.chat("follow up", Some("persist")).await.unwrap();

    let req = reader_provider.last_request().unwrap();
    let joined = req
        .messages
        .iter()
        .map(|m| m.content.as_str())
        .collect::<Vec<_>>()
        .join("\n");
    assert!(
        joined.contains("remember this"),
        "persisted round must be resumed by a fresh agent, got: {joined}"
    );
}
/// With `MockResolver::degraded()` (no inference provider), `chat` fails with
/// `SmAgentError::Degraded` whose message mentions
/// "no inference provider configured".
#[tokio::test]
async fn chat_without_provider_is_degraded() {
    let tmp = TempDir::new().unwrap();
    let agent = agent_with(
        enabled_config(),
        Arc::new(MockResolver::degraded()),
        tmp.path(),
    );
    let err = agent.chat("hello", Some("c")).await.unwrap_err();
    match err {
        SmAgentError::Degraded(msg) => {
            // Echo the actual reason so a wording change is diagnosable from the failure.
            assert!(
                msg.contains("no inference provider configured"),
                "unexpected degraded reason: {msg}"
            );
        }
        other => panic!("expected Degraded, got {other:?}"),
    }
}
/// An agent built with `SessionManagerAgent::new` is constructed without a
/// resolver or data root, unlike `with_runtime`. Its `chat` call must report
/// `SmAgentError::Degraded` rather than attempting a turn.
#[tokio::test]
async fn chat_without_runtime_is_degraded() {
    let agent = SessionManagerAgent::new(enabled_config());
    let err = agent.chat("hello", None).await.unwrap_err();
    // `matches!` does not move `err`, so it can still be printed on failure.
    assert!(
        matches!(err, SmAgentError::Degraded(_)),
        "expected Degraded, got {err:?}"
    );
}
/// When provider resolution fails (the mock's `validation()` mode — presumably
/// a validation error from the resolver), `chat` surfaces
/// `SmAgentError::Inference` rather than `Degraded`.
#[tokio::test]
async fn chat_resolution_error_is_inference_error() {
    let tmp = TempDir::new().unwrap();
    let agent = agent_with(
        enabled_config(),
        Arc::new(MockResolver::validation()),
        tmp.path(),
    );
    let err = agent.chat("hello", Some("c")).await.unwrap_err();
    // `matches!` does not move `err`, so it can still be printed on failure.
    assert!(
        matches!(err, SmAgentError::Inference(_)),
        "expected Inference, got {err:?}"
    );
}
/// Passing `None` as the conversation id makes the agent mint one; the id in
/// the outcome must be non-blank.
#[tokio::test]
async fn chat_mints_conv_id_when_absent() {
    let tmp = TempDir::new().unwrap();
    let provider = MockChatProvider::new("ok", 0.0);
    let resolver = Arc::new(MockResolver::with_provider(provider));
    let agent = agent_with(enabled_config(), resolver, tmp.path());

    let outcome = agent.chat("hi", None).await.unwrap();
    let minted = outcome.conv_id.trim();
    assert!(!minted.is_empty(), "a conv_id must be minted");
}
/// With no memory store attached (the test helper passes `None` under
/// `sm-memory`), the turn completes normally and the system prompt contains
/// no "Relevant memory:" block.
#[tokio::test]
async fn chat_works_without_memory_recall() {
    let tmp = TempDir::new().unwrap();
    let provider = MockChatProvider::new("ok", 0.0);
    let agent = agent_with(
        enabled_config(),
        Arc::new(MockResolver::with_provider(provider.clone())),
        tmp.path(),
    );

    let outcome = agent.chat("no recall here", Some("c")).await.unwrap();
    assert_eq!(outcome.reply, "ok");

    let system = provider.last_request().unwrap().system;
    let has_memory_block = system.contains("Relevant memory:");
    assert!(
        !has_memory_block,
        "no recall is wired, so no memory block should appear"
    );
}
/// The first agent's resolver is `provider_then_degraded(provider, 1)`, which
/// presumably hands out the provider for one resolution and then degrades, so
/// no provider is left for compaction. The turn must still succeed with the
/// provider's reply. A fresh agent on the same data root must then see the
/// original user message verbatim in its request.
#[tokio::test]
async fn chat_records_round_when_no_provider_for_compaction() {
    let tmp = TempDir::new().unwrap();
    let root = tmp.path();

    let first = agent_with(
        enabled_config(),
        Arc::new(MockResolver::provider_then_degraded(
            MockChatProvider::new("the plan", 0.0),
            1,
        )),
        root,
    );
    let outcome = first
        .chat("decompose this", Some("conv-int"))
        .await
        .expect("turn succeeds: the reply was produced");
    assert_eq!(outcome.reply, "the plan");

    // A second, independent agent resumes the same conversation from disk.
    let follow_provider = MockChatProvider::new("ack", 0.0);
    let second = agent_with(
        enabled_config(),
        Arc::new(MockResolver::with_provider(follow_provider.clone())),
        root,
    );
    second
        .chat("next", Some("conv-int"))
        .await
        .expect("follow-up turn");

    let req = follow_provider.last_request().expect("provider called");
    let joined = req
        .messages
        .iter()
        .map(|m| m.content.as_str())
        .collect::<Vec<_>>()
        .join("\n");
    assert!(
        joined.contains("decompose this"),
        "the round must be recorded verbatim even when compaction has no provider, got: {joined}"
    );
}
/// With an `SmMemory` store holding a relevant fact, the system prompt gains a
/// "Relevant memory:" block that contains the recalled fact's text.
///
/// Only compiled with the `sm-memory` feature. A mock embedder is seeded
/// first via `seed_shared_embedder_with_mock`.
#[cfg(feature = "sm-memory")]
#[tokio::test]
async fn chat_includes_recall_when_memory_present() {
    use crate::core::sm::config::SmMemoryConfig;
    use crate::core::sm::memory::SmMemory;
    use trusty_common::memory_core::retrieval::seed_shared_embedder_with_mock;

    seed_shared_embedder_with_mock();
    let tmp = TempDir::new().unwrap();

    // Memory lives under `<tmp>/palace`; the agent's data root is `<tmp>` itself.
    let palace_dir = tmp.path().join("palace");
    let mem = SmMemory::open(palace_dir, &SmMemoryConfig::default()).expect("open SM memory");
    mem.remember("project trusty-tools requires SKIP_UI_BUILD=1 for cargo publish")
        .await
        .expect("remember a fact");

    let provider = MockChatProvider::new("noted", 0.0);
    let agent = SessionManagerAgent::with_runtime(
        enabled_config(),
        Arc::new(MockResolver::with_provider(provider.clone())),
        tmp.path().to_path_buf(),
        Some(mem),
    );
    agent
        .chat("how do I cargo publish trusty-tools?", Some("c"))
        .await
        .expect("chat with recall succeeds");

    let system = provider.last_request().unwrap().system;
    assert!(
        system.contains("Relevant memory:"),
        "a recall block must be injected when memory holds a relevant fact"
    );
    assert!(
        system.contains("SKIP_UI_BUILD"),
        "the recalled fact content must appear in the working prompt, got: {}",
        system
    );
}