use std::sync::Arc;
use async_trait::async_trait;
use rig_compose::delegate::{DelegateExecutor, DelegateTool};
use rig_compose::registry::{KernelError, ToolRegistry};
use rig_compose::tool::ToolSchema;
use rig_core::completion::CompletionModel;
use serde_json::{Value, json};
use crate::agent::DciAgent;
use crate::session::{DEFAULT_SESSION_ID, InvestigationTurn, SessionConfig, SessionStore};
/// Tool name used by [`DciMcpService::new`] and [`DciMcpService::new_with_config`]
/// when no explicit tool name is given.
pub const DEFAULT_TOOL_NAME: &str = "dci_investigate";
/// Maximum number of prior turns of a session replayed into the agent prompt.
const MAX_HISTORY_TURNS_IN_PROMPT: usize = 6;
/// [`DelegateExecutor`] that answers questions with a [`DciAgent`] and keeps a
/// per-session history of question/answer turns in a [`SessionStore`].
pub struct DciDelegate<M: CompletionModel> {
    /// Agent that performs each investigation.
    agent: Arc<DciAgent<M>>,
    /// History store: read to build the prompt, written after a successful turn.
    sessions: SessionStore,
}
impl<M: CompletionModel + 'static> DciDelegate<M> {
    /// Wraps `agent` with a store obtained from `SessionStore::new()`.
    pub fn new(agent: Arc<DciAgent<M>>) -> Self {
        Self::with_store(agent, SessionStore::new())
    }

    /// Wraps `agent` with a caller-supplied session store, e.g. one the caller
    /// also keeps a handle to for inspecting recorded turns.
    pub fn with_store(agent: Arc<DciAgent<M>>, sessions: SessionStore) -> Self {
        Self { sessions, agent }
    }

    /// Session store this delegate reads history from and records turns into.
    pub fn sessions(&self) -> &SessionStore {
        &self.sessions
    }
}
#[async_trait]
impl<M: CompletionModel + 'static> DelegateExecutor for DciDelegate<M> {
    /// Answers one question, continuing the session named in `args`.
    ///
    /// The steps are:
    /// 1. Validate `args` with `parse_invoke_args`.
    /// 2. Prepend up to `MAX_HISTORY_TURNS_IN_PROMPT` recent turns of the session to the question.
    /// 3. Run the agent through `crate::telemetry::with_session` with the session id.
    /// 4. Record the new turn.
    ///
    /// The reply is a JSON object with `session_id`, `question`, `answer`,
    /// `turn_index` and `turn_count`.
    ///
    /// # Errors
    ///
    /// Propagates argument-validation errors. Returns `KernelError::ToolFailed`
    /// when the investigation fails; in that case nothing is recorded.
    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
        let request = parse_invoke_args(&args)?;

        let recent = self
            .sessions
            .recent_history(&request.session_id, MAX_HISTORY_TURNS_IN_PROMPT);
        let prompt = build_continuation_prompt(&recent, &request.question);

        let answer = crate::telemetry::with_session(request.session_id.clone(), || async {
            self.agent
                .investigate(&prompt)
                .await
                .map_err(|err| KernelError::ToolFailed(format!("investigation failed: {err}")))
        })
        .await?;

        // Only reached on success, so failed investigations never enter the history.
        let (turn_index, turn_count) =
            self.sessions
                .record(&request.session_id, &request.question, &answer);

        let reply = json!({
            "session_id": request.session_id,
            "question": request.question,
            "answer": answer,
            "turn_index": turn_index,
            "turn_count": turn_count,
        });
        Ok(reply)
    }
}
/// Validated arguments for one `invoke` call, produced by `parse_invoke_args`.
#[derive(Debug)]
struct InvokeArgs {
    /// Trimmed, non-empty question text.
    question: String,
    /// Trimmed session id, or `DEFAULT_SESSION_ID` when no usable id was supplied.
    session_id: String,
}
/// Extracts and validates the tool arguments for one investigation turn.
///
/// - `question` is required. It must be a string that is non-empty after
///   trimming, and it is stored trimmed.
/// - `session_id` is optional. When it is absent, `null`, or blank, the turn is
///   filed under `DEFAULT_SESSION_ID`; otherwise the trimmed string is used.
///
/// # Errors
///
/// Returns `KernelError::InvalidArgument` in two cases:
/// - `question` is missing, not a string, or blank.
/// - `session_id` is present but is neither a string nor `null`.
fn parse_invoke_args(args: &Value) -> Result<InvokeArgs, KernelError> {
    let question = args
        .get("question")
        .and_then(Value::as_str)
        .map(str::trim)
        .filter(|q| !q.is_empty())
        .ok_or_else(|| {
            KernelError::InvalidArgument(
                "dci_investigate requires a non-empty string 'question' field".to_string(),
            )
        })?
        .to_string();

    let session_id: &str = match args.get("session_id") {
        // Absent or explicit null: the caller did not name a session.
        None | Some(Value::Null) => DEFAULT_SESSION_ID,
        Some(Value::String(raw)) => match raw.trim() {
            "" => DEFAULT_SESSION_ID,
            trimmed => trimmed,
        },
        // A wrongly-typed id (number, bool, object, ...) must not be silently
        // routed into the shared default session, where it would mix its turns
        // into unrelated history. Reject it, as a non-string `question` is rejected.
        Some(_) => {
            return Err(KernelError::InvalidArgument(
                "dci_investigate 'session_id' must be a string when provided".to_string(),
            ));
        }
    };

    Ok(InvokeArgs {
        question,
        session_id: session_id.to_string(),
    })
}
/// Assembles the prompt for one investigation turn.
///
/// An empty `history` yields `question` unchanged, so a session's first turn is
/// sent verbatim. Otherwise the prompt contains, in order:
/// - a preamble saying how many prior exchanges follow;
/// - the last `MAX_HISTORY_TURNS_IN_PROMPT` entries of `history`, oldest first,
///   numbered relative to that window as `k/shown`;
/// - an instruction to build on that evidence;
/// - `question`.
fn build_continuation_prompt(history: &[InvestigationTurn], question: &str) -> String {
    if history.is_empty() {
        return question.to_string();
    }

    // Keep only the tail of the history; older turns are dropped from the prompt.
    let dropped = history.len().saturating_sub(MAX_HISTORY_TURNS_IN_PROMPT);
    let window = history.iter().skip(dropped);
    let total = window.len();

    let preamble = format!(
        "You are continuing an existing investigation of this corpus. The {total} most \
         recent exchange(s), oldest first:\n"
    );
    let exchanges: String = window
        .enumerate()
        .map(|(position, turn)| {
            format!(
                "\n[recent exchange {}/{}]\n Q: {}\n A: {}\n",
                position + 1,
                total,
                turn.question,
                turn.answer
            )
        })
        .collect();

    [
        preamble.as_str(),
        exchanges.as_str(),
        "\nBuild on the evidence already gathered above. New question:\n",
        question,
    ]
    .concat()
}
/// Builds the schema advertised for the investigation tool registered as `name`.
///
/// Arguments:
/// - `question`: required string.
/// - `session_id`: optional string.
///
/// The result is an object echoing `session_id` and `question`, plus `answer`
/// and the integer `turn_index` / `turn_count` values.
fn tool_schema(name: &str) -> ToolSchema {
    let description = "Investigate a corpus by direct interaction — searching, finding, and \
        reading the raw files (no vector database). Provide a natural-language \
        'question'; pass a stable 'session_id' to continue a prior investigation \
        and build on its evidence.";

    let question_arg = json!({
        "type": "string",
        "description": "The question to answer over the corpus."
    });
    // NOTE(review): "omit for a one-off" may not match behaviour. When the id is
    // omitted, `parse_invoke_args` files the turn under `DEFAULT_SESSION_ID`, so
    // omitted-id calls share history unless `SessionStore` special-cases that id.
    // Confirm in the session module.
    let session_arg = json!({
        "type": "string",
        "description": "Optional id to continue an existing investigation; omit for a one-off."
    });

    let args_schema = json!({
        "type": "object",
        "required": ["question"],
        "properties": {
            "question": question_arg,
            "session_id": session_arg
        }
    });
    let result_schema = json!({
        "type": "object",
        "properties": {
            "session_id": { "type": "string" },
            "question": { "type": "string" },
            "answer": { "type": "string" },
            "turn_index": { "type": "integer" },
            "turn_count": { "type": "integer" }
        }
    });

    ToolSchema {
        name: name.to_owned(),
        description: description.to_owned(),
        args_schema,
        result_schema,
    }
}
/// MCP-facing wrapper that registers a single [`DciDelegate`]-backed tool in a
/// [`ToolRegistry`] and can serve that registry over stdio.
pub struct DciMcpService {
    /// Registry holding the single investigation tool.
    registry: ToolRegistry,
    /// Session store created for this service. A clone of it is handed to the
    /// registered delegate; presumably the clones share state — confirm in
    /// `SessionStore`.
    sessions: SessionStore,
    /// Name the investigation tool was registered under.
    tool_name: String,
}
impl DciMcpService {
    /// Creates a service that registers `agent` under [`DEFAULT_TOOL_NAME`],
    /// using `SessionConfig::default()`.
    pub fn new<M: CompletionModel + 'static>(agent: DciAgent<M>) -> Self {
        Self::with_config(agent, DEFAULT_TOOL_NAME, SessionConfig::default())
    }

    /// Creates a service under [`DEFAULT_TOOL_NAME`] with a custom session configuration.
    pub fn new_with_config<M: CompletionModel + 'static>(
        agent: DciAgent<M>,
        session_config: SessionConfig,
    ) -> Self {
        Self::with_config(agent, DEFAULT_TOOL_NAME, session_config)
    }

    /// Creates a service that registers the tool as `tool_name`, using
    /// `SessionConfig::default()`.
    pub fn with_name<M: CompletionModel + 'static>(
        agent: DciAgent<M>,
        tool_name: impl Into<String>,
    ) -> Self {
        Self::with_config(agent, tool_name, SessionConfig::default())
    }

    /// Fully configurable constructor; every other constructor forwards here.
    ///
    /// Steps:
    /// 1. Build a `SessionStore` from `session_config`.
    /// 2. Wrap `agent` in a [`DciDelegate`] that receives a clone of that store.
    /// 3. Register the delegate as one [`DelegateTool`] whose schema is named `tool_name`.
    pub fn with_config<M: CompletionModel + 'static>(
        agent: DciAgent<M>,
        tool_name: impl Into<String>,
        session_config: SessionConfig,
    ) -> Self {
        let tool_name: String = tool_name.into();
        let sessions = SessionStore::with_config(session_config);

        let executor: Arc<dyn DelegateExecutor> =
            Arc::new(DciDelegate::with_store(Arc::new(agent), sessions.clone()));
        let registry = ToolRegistry::new();
        registry.register(Arc::new(DelegateTool::with_schema(
            tool_schema(&tool_name),
            executor,
        )));

        Self {
            registry,
            sessions,
            tool_name,
        }
    }

    /// Registry containing the investigation tool.
    pub fn registry(&self) -> &ToolRegistry {
        &self.registry
    }

    /// Session store created for this service.
    pub fn sessions(&self) -> &SessionStore {
        &self.sessions
    }

    /// Name the investigation tool was registered under.
    pub fn tool_name(&self) -> &str {
        &self.tool_name
    }

    /// Consumes the service and serves its registry via `rig_mcp::serve_stdio`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `rig_mcp::serve_stdio`.
    pub async fn serve_stdio(self) -> Result<(), KernelError> {
        let Self { registry, .. } = self;
        rig_mcp::serve_stdio(registry).await
    }
}
#[cfg(test)]
mod tests {
    // Tests fail by panicking (unwrap/expect/index), so these lints are relaxed
    // here in case the crate denies them elsewhere.
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::indexing_slicing,
        clippy::panic
    )]
    use super::*;
    use std::time::SystemTime;

    /// Builds a history turn with the given question/answer, stamped now.
    fn turn(q: &str, a: &str) -> InvestigationTurn {
        InvestigationTurn {
            question: q.to_string(),
            answer: a.to_string(),
            at: SystemTime::now(),
        }
    }

    /// Twenty turns `q0/a0` .. `q19/a19`, oldest first.
    fn twenty_turns() -> Vec<InvestigationTurn> {
        (0..20)
            .map(|i| turn(&format!("q{i}"), &format!("a{i}")))
            .collect()
    }

    #[test]
    fn parse_args_requires_question() {
        let err = parse_invoke_args(&json!({ "session_id": "s" })).unwrap_err();
        assert!(matches!(err, KernelError::InvalidArgument(_)));
        let err = parse_invoke_args(&json!({ "question": " " })).unwrap_err();
        assert!(matches!(err, KernelError::InvalidArgument(_)));
    }

    #[test]
    fn parse_args_defaults_session() {
        let parsed = parse_invoke_args(&json!({ "question": "who logged in?" })).unwrap();
        assert_eq!(parsed.question, "who logged in?");
        assert_eq!(parsed.session_id, DEFAULT_SESSION_ID);
        let parsed =
            parse_invoke_args(&json!({ "question": "q", "session_id": "case-7" })).unwrap();
        assert_eq!(parsed.session_id, "case-7");
    }

    #[test]
    fn parse_args_trims_and_blank_session_falls_back() {
        let parsed =
            parse_invoke_args(&json!({ "question": "  q  ", "session_id": "  case-7 " }))
                .unwrap();
        assert_eq!(parsed.question, "q");
        assert_eq!(parsed.session_id, "case-7");
        let parsed =
            parse_invoke_args(&json!({ "question": "q", "session_id": "   " })).unwrap();
        assert_eq!(parsed.session_id, DEFAULT_SESSION_ID);
    }

    #[test]
    fn first_turn_prompt_is_verbatim() {
        let prompt = build_continuation_prompt(&[], "find the brute force source");
        assert_eq!(prompt, "find the brute force source");
    }

    #[test]
    fn continuation_prompt_includes_prior_turns() {
        let history = vec![turn("who logged in?", "alice from 10.0.0.5")];
        let prompt = build_continuation_prompt(&history, "and when?");
        assert!(prompt.contains("continuing an existing investigation"));
        assert!(prompt.contains("alice from 10.0.0.5"));
        assert!(prompt.ends_with("and when?"));
    }

    #[test]
    fn continuation_prompt_caps_history() {
        let prompt = build_continuation_prompt(&twenty_turns(), "next");
        // The window is the MAX most recent turns; the turn just before it is dropped.
        let first_shown = 20 - MAX_HISTORY_TURNS_IN_PROMPT;
        assert!(prompt.contains("a19"));
        assert!(prompt.contains(&format!("q{first_shown}")));
        assert!(!prompt.contains(&format!("q{}", first_shown - 1)));
        assert!(!prompt.contains("a0\n"));
        assert_eq!(
            prompt.matches("[recent exchange ").count(),
            MAX_HISTORY_TURNS_IN_PROMPT
        );
    }

    #[test]
    fn continuation_prompt_numbers_relative_to_window() {
        let prompt = build_continuation_prompt(&twenty_turns(), "next");
        let n = MAX_HISTORY_TURNS_IN_PROMPT;
        assert!(prompt.contains(&format!("[recent exchange 1/{n}]")));
        assert!(prompt.contains(&format!("[recent exchange {n}/{n}]")));
        // Numbering restarts at 1 for the window instead of using absolute indices.
        assert!(!prompt.contains(&format!("[recent exchange {}/", n + 1)));
    }

    #[test]
    fn schema_advertises_required_question() {
        let schema = tool_schema(DEFAULT_TOOL_NAME);
        assert_eq!(schema.name, DEFAULT_TOOL_NAME);
        let req_array = schema
            .args_schema
            .get("required")
            .and_then(|v| v.as_array())
            .unwrap();
        assert_eq!(req_array.first().unwrap().as_str().unwrap(), "question");
    }
}