use crate::brain::agent::service::AgentService;
use crate::brain::tools::ToolRegistry;
use crate::db::Database;
use crate::services::{MessageService, ServiceContext, SessionService};
use crate::tests::agent_service_mocks::*;
use std::sync::Arc;
#[tokio::test]
async fn test_concurrent_sessions_independent() {
let db = Database::connect_in_memory().await.unwrap();
db.run_migrations().await.unwrap();
let pool = db.pool().clone();
let context = ServiceContext::new(pool);
let provider = Arc::new(MockProvider);
let agent_service = Arc::new(AgentService::new_for_test(provider, context.clone()).await);
let session_service = SessionService::new(context);
let session_a = session_service
.create_session(Some("Session A".to_string()))
.await
.unwrap();
let session_b = session_service
.create_session(Some("Session B".to_string()))
.await
.unwrap();
let id_a = session_a.id;
let id_b = session_b.id;
let resp_a = agent_service
.send_message(id_a, "Hello from A".to_string(), None)
.await
.unwrap();
let resp_b = agent_service
.send_message(id_b, "Hello from B".to_string(), None)
.await
.unwrap();
assert!(
!resp_a.content.is_empty(),
"session A should have a response"
);
assert!(
!resp_b.content.is_empty(),
"session B should have a response"
);
assert_eq!(resp_a.model, "mock-model");
assert_eq!(resp_b.model, "mock-model");
}
#[tokio::test]
async fn test_concurrent_sessions_different_providers() {
let db = Database::connect_in_memory().await.unwrap();
db.run_migrations().await.unwrap();
let pool = db.pool().clone();
let context = ServiceContext::new(pool);
let provider_alpha = Arc::new(MockProviderWithModel::new("alpha", "alpha-model"));
let provider_beta = Arc::new(MockProviderWithModel::new("beta", "beta-model"));
let svc_alpha = Arc::new(AgentService::new_for_test(provider_alpha, context.clone()).await);
let svc_beta = Arc::new(AgentService::new_for_test(provider_beta, context.clone()).await);
let session_service = SessionService::new(context);
let session_a = session_service
.create_session(Some("Alpha Session".to_string()))
.await
.unwrap();
let session_b = session_service
.create_session(Some("Beta Session".to_string()))
.await
.unwrap();
let id_a = session_a.id;
let id_b = session_b.id;
let resp_a = svc_alpha
.send_message(id_a, "Hello alpha".to_string(), None)
.await
.unwrap();
let resp_b = svc_beta
.send_message(id_b, "Hello beta".to_string(), None)
.await
.unwrap();
assert!(
resp_a.content.contains("alpha"),
"session A should get response from alpha provider, got: {}",
resp_a.content
);
assert!(
resp_b.content.contains("beta"),
"session B should get response from beta provider, got: {}",
resp_b.content
);
}
#[tokio::test]
async fn test_cancel_one_session_other_continues() {
use tokio_util::sync::CancellationToken;
let db = Database::connect_in_memory().await.unwrap();
db.run_migrations().await.unwrap();
let pool = db.pool().clone();
let context = ServiceContext::new(pool);
let provider = Arc::new(MockProviderWithTools::new());
let registry = ToolRegistry::new();
registry.register(Arc::new(MockTool));
let svc_a = Arc::new(
AgentService::new_for_test(Arc::new(MockProviderWithTools::new()), context.clone())
.await
.with_tool_registry(Arc::new({
let r = ToolRegistry::new();
r.register(Arc::new(MockTool));
r
}))
.with_auto_approve_tools(true),
);
let svc_b = Arc::new(
AgentService::new_for_test(provider, context.clone())
.await
.with_tool_registry(Arc::new(registry))
.with_auto_approve_tools(true),
);
let session_service = SessionService::new(context);
let session_a = session_service
.create_session(Some("Session A".to_string()))
.await
.unwrap();
let session_b = session_service
.create_session(Some("Session B".to_string()))
.await
.unwrap();
let cancel_a = CancellationToken::new();
cancel_a.cancel();
let id_a = session_a.id;
let id_b = session_b.id;
let result_a = svc_a
.send_message_with_tools_and_mode(
id_a,
"Use the tool".to_string(),
None,
Some(cancel_a.clone()),
)
.await;
let result_b = svc_b
.send_message_with_tools(id_b, "Use the tool".to_string(), None)
.await;
let _ = result_a;
let resp_b = result_b.unwrap();
assert!(
!resp_b.content.is_empty(),
"session B should complete normally"
);
}
#[tokio::test]
async fn test_message_isolation_between_sessions() {
let db = Database::connect_in_memory().await.unwrap();
db.run_migrations().await.unwrap();
let pool = db.pool().clone();
let context = ServiceContext::new(pool);
let provider = Arc::new(MockProvider);
let agent_service = Arc::new(AgentService::new_for_test(provider, context.clone()).await);
let session_service = SessionService::new(context.clone());
let session_a = session_service
.create_session(Some("Session A".to_string()))
.await
.unwrap();
let session_b = session_service
.create_session(Some("Session B".to_string()))
.await
.unwrap();
agent_service
.send_message(session_a.id, "Message only in session A".to_string(), None)
.await
.unwrap();
agent_service
.send_message(session_b.id, "Message only in session B".to_string(), None)
.await
.unwrap();
let message_service = MessageService::new(context);
let msgs_a = message_service
.list_messages_for_session(session_a.id)
.await
.unwrap();
let msgs_b = message_service
.list_messages_for_session(session_b.id)
.await
.unwrap();
let text_a: Vec<&str> = msgs_a.iter().map(|m| m.content.as_str()).collect();
let text_b: Vec<&str> = msgs_b.iter().map(|m| m.content.as_str()).collect();
assert!(
text_a.iter().any(|t| t.contains("only in session A")),
"session A messages should contain session-A message"
);
assert!(
!text_a.iter().any(|t| t.contains("only in session B")),
"session A messages should NOT contain session-B message"
);
assert!(
text_b.iter().any(|t| t.contains("only in session B")),
"session B messages should contain session-B message"
);
assert!(
!text_b.iter().any(|t| t.contains("only in session A")),
"session B messages should NOT contain session-A message"
);
}
#[tokio::test]
async fn test_session_usage_tracked_independently() {
let db = Database::connect_in_memory().await.unwrap();
db.run_migrations().await.unwrap();
let pool = db.pool().clone();
let context = ServiceContext::new(pool);
let provider = Arc::new(MockProvider);
let agent_service = Arc::new(AgentService::new_for_test(provider, context.clone()).await);
let session_service = SessionService::new(context.clone());
let session_a = session_service
.create_session(Some("Session A".to_string()))
.await
.unwrap();
let session_b = session_service
.create_session(Some("Session B".to_string()))
.await
.unwrap();
agent_service
.send_message(session_a.id, "First message to A".to_string(), None)
.await
.unwrap();
agent_service
.send_message(session_a.id, "Second message to A".to_string(), None)
.await
.unwrap();
agent_service
.send_message(session_b.id, "Only message to B".to_string(), None)
.await
.unwrap();
let updated_a = session_service
.get_session_required(session_a.id)
.await
.unwrap();
let updated_b = session_service
.get_session_required(session_b.id)
.await
.unwrap();
assert_eq!(
updated_a.token_count, 60,
"session A should have 60 tokens (2 messages × 30)"
);
assert_eq!(
updated_b.token_count, 30,
"session B should have 30 tokens (1 message × 30)"
);
assert!(
updated_a.total_cost > updated_b.total_cost,
"session A (2 msgs) should have higher cost than session B (1 msg)"
);
}