dci-tool 0.1.0

Direct Corpus Interaction: a sandboxed, ripgrep-backed corpus-search toolset and agent for cyber-focused LLM agents, built on rig.
Documentation
//! Model Context Protocol exposure of the DCI agent.
//!
//! Phase 2 wraps the whole [`DciAgent`] as a single, stateful MCP tool:
//!
//! * the rig-core agent is wrapped in a [`rig_compose`] [`DelegateExecutor`],
//! * exposed as a [`DelegateTool`] in a [`ToolRegistry`],
//! * and served over stdio with [`rig_mcp::serve_stdio`].
//!
//! Because MCP treats every tool call as independent, continuity is provided by
//! a [`SessionStore`]: callers pass a `session_id` to keep building on an
//! existing investigation. Prior question/answer turns for that session are
//! folded back into the prompt so the agent can reason from accumulated
//! evidence rather than starting cold.
//!
//! This module is gated behind the `mcp` feature.

use std::sync::Arc;

use async_trait::async_trait;
use rig_compose::delegate::{DelegateExecutor, DelegateTool};
use rig_compose::registry::{KernelError, ToolRegistry};
use rig_compose::tool::ToolSchema;
use rig_core::completion::CompletionModel;
use serde_json::{Value, json};

use crate::agent::DciAgent;
use crate::session::{DEFAULT_SESSION_ID, InvestigationTurn, SessionConfig, SessionStore};

/// The default name of the MCP tool exposed by [`DciMcpService`].
///
/// Used by [`DciMcpService::new`] and [`DciMcpService::new_with_config`]; the
/// `with_name` / `with_config` constructors accept a custom name instead.
pub const DEFAULT_TOOL_NAME: &str = "dci_investigate";

/// Maximum number of prior turns folded into a continuation prompt.
///
/// Keeps prompts from growing without bound in long sessions: both the history
/// fetched from the [`SessionStore`] in `invoke` and the window rendered by
/// `build_continuation_prompt` are capped at this many turns.
const MAX_HISTORY_TURNS_IN_PROMPT: usize = 6;

/// A [`DelegateExecutor`] that runs a [`DciAgent`] and tracks per-session state.
///
/// Each invocation folds the session's recent turns into the prompt, runs the
/// agent, and records the new question/answer turn in the session store.
pub struct DciDelegate<M: CompletionModel> {
    /// The agent that performs each investigation, shared behind an `Arc`.
    agent: Arc<DciAgent<M>>,
    /// Per-session question/answer history. May be supplied by the caller via
    /// `with_store`; presumably clones of a `SessionStore` share the same
    /// underlying state (that is how `DciMcpService` inspects it) — confirm in
    /// `crate::session`.
    sessions: SessionStore,
}

impl<M: CompletionModel + 'static> DciDelegate<M> {
    /// Create a delegate for `agent` backed by its own, brand-new
    /// [`SessionStore`].
    pub fn new(agent: Arc<DciAgent<M>>) -> Self {
        Self::with_store(agent, SessionStore::new())
    }

    /// Create a delegate for `agent` that records turns into `sessions`, so
    /// the caller can keep its own handle on the same store (e.g. to inspect
    /// investigation history).
    pub fn with_store(agent: Arc<DciAgent<M>>, sessions: SessionStore) -> Self {
        Self { sessions, agent }
    }

    /// Borrow the store this delegate reads history from and records turns
    /// into.
    pub fn sessions(&self) -> &SessionStore {
        &self.sessions
    }
}

#[async_trait]
impl<M: CompletionModel + 'static> DelegateExecutor for DciDelegate<M> {
    /// Run one investigation step for the session named in `args`.
    ///
    /// The payload is validated first. Then the session's most recent turns
    /// (at most `MAX_HISTORY_TURNS_IN_PROMPT`) are folded into the prompt, and
    /// the agent runs inside `crate::telemetry::with_session` for that
    /// session. The turn is recorded only after the agent succeeds. The
    /// returned JSON object carries `session_id`, `question`, `answer`,
    /// `turn_index` and `turn_count`.
    ///
    /// NOTE(review): payloads without a usable `session_id` all resolve to
    /// `DEFAULT_SESSION_ID`. Such "one-off" calls therefore read and extend the
    /// same shared history unless `SessionStore` special-cases that id —
    /// confirm.
    ///
    /// # Errors
    ///
    /// * `KernelError::InvalidArgument` if the payload fails validation.
    /// * `KernelError::ToolFailed` if the investigation itself fails; nothing
    ///   is recorded for the session in that case.
    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
        let parsed = parse_invoke_args(&args)?;
        let (question, session_id) = (parsed.question, parsed.session_id);

        // Only a bounded window of prior turns is fetched for the prompt.
        let recent = self
            .sessions
            .recent_history(&session_id, MAX_HISTORY_TURNS_IN_PROMPT);
        let prompt = build_continuation_prompt(&recent, &question);

        let answer = crate::telemetry::with_session(session_id.clone(), || async {
            let outcome = self.agent.investigate(&prompt).await;
            outcome.map_err(|err| KernelError::ToolFailed(format!("investigation failed: {err}")))
        })
        .await?;

        // Reached only on success, so failed investigations leave no trace.
        let (turn_index, turn_count) = self.sessions.record(&session_id, &question, &answer);

        Ok(json!({
            "session_id": session_id,
            "question": question,
            "answer": answer,
            "turn_index": turn_index,
            "turn_count": turn_count,
        }))
    }
}

/// Parsed arguments for a `dci_investigate` call.
#[derive(Debug)]
struct InvokeArgs {
    /// The caller's question, trimmed of surrounding whitespace.
    question: String,
    /// The session to read history from and record into; falls back to
    /// `DEFAULT_SESSION_ID` when the caller omitted it or left it blank.
    session_id: String,
}

/// Extract and validate the `question` (required) and `session_id` (optional)
/// fields from a tool-call payload.
///
/// * `question` must be a string that is non-empty after trimming; the trimmed
///   text is returned.
/// * `session_id` may be absent, `null`, or a string. When it is absent,
///   `null`, or blank (whitespace-only), it falls back to
///   [`DEFAULT_SESSION_ID`]. Otherwise the trimmed string is used.
///
/// # Errors
///
/// Returns [`KernelError::InvalidArgument`] when `question` is missing, not a
/// string, or blank. It also does so when `session_id` is present with a
/// non-string, non-null value. Rejecting a mistyped `session_id` (instead of
/// silently substituting the default session) prevents a client that sends,
/// say, a numeric id from unknowingly reading and appending to the shared
/// default session's history.
fn parse_invoke_args(args: &Value) -> Result<InvokeArgs, KernelError> {
    let question = args
        .get("question")
        .and_then(Value::as_str)
        .map(str::trim)
        .filter(|q| !q.is_empty())
        .ok_or_else(|| {
            KernelError::InvalidArgument(
                "dci_investigate requires a non-empty string 'question' field".to_string(),
            )
        })?
        .to_string();

    let session_id = match args.get("session_id") {
        // Omitted or explicitly null: use the default session.
        None | Some(Value::Null) => DEFAULT_SESSION_ID,
        Some(Value::String(raw)) => {
            let trimmed = raw.trim();
            if trimmed.is_empty() {
                DEFAULT_SESSION_ID
            } else {
                trimmed
            }
        }
        // Wrong JSON type: surface the client bug instead of guessing.
        Some(_) => {
            return Err(KernelError::InvalidArgument(
                "dci_investigate requires 'session_id' to be a string when provided"
                    .to_string(),
            ));
        }
    };

    Ok(InvokeArgs {
        question,
        session_id: session_id.to_string(),
    })
}

/// Compose the prompt sent to the agent for one call.
///
/// With no prior turns, the question is returned unchanged. Otherwise the
/// newest turns are rendered oldest first as labelled `Q:`/`A:` pairs, placed
/// ahead of the new question so the agent can build on earlier evidence.
///
/// Callers normally pass an already-bounded slice (see
/// [`SessionStore::recent_history`]), but at most
/// `MAX_HISTORY_TURNS_IN_PROMPT` turns are rendered regardless. Labels count
/// within the rendered window (`1/N` … `N/N`) rather than using absolute turn
/// numbers, so they stay accurate however long the investigation runs.
fn build_continuation_prompt(history: &[InvestigationTurn], question: &str) -> String {
    if history.is_empty() {
        // First turn of a session: nothing to build on.
        return question.to_string();
    }

    // Drop everything older than the window, even if the caller already did.
    let dropped = history.len().saturating_sub(MAX_HISTORY_TURNS_IN_PROMPT);
    let total = history.len() - dropped;

    let exchanges: String = history
        .iter()
        .skip(dropped)
        .enumerate()
        .map(|(idx, past)| {
            format!(
                "\n[recent exchange {}/{}]\n  Q: {}\n  A: {}\n",
                idx + 1,
                total,
                past.question,
                past.answer
            )
        })
        .collect();

    let mut prompt = format!(
        "You are continuing an existing investigation of this corpus. The {total} most \
         recent exchange(s), oldest first:\n",
    );
    prompt.push_str(&exchanges);
    prompt.push_str("\nBuild on the evidence already gathered above. New question:\n");
    prompt.push_str(question);
    prompt
}

/// Describe the DCI tool to MCP clients.
///
/// `name` becomes the advertised tool name. The argument schema requires a
/// string `question` and accepts an optional string `session_id`. The result
/// schema mirrors the JSON object returned by `DciDelegate::invoke`.
fn tool_schema(name: &str) -> ToolSchema {
    let description = concat!(
        "Investigate a corpus by direct interaction — searching, finding, and ",
        "reading the raw files (no vector database). Provide a natural-language ",
        "'question'; pass a stable 'session_id' to continue a prior investigation ",
        "and build on its evidence."
    );

    let question_arg = json!({
        "type": "string",
        "description": "The question to answer over the corpus."
    });
    let session_arg = json!({
        "type": "string",
        "description": "Optional id to continue an existing investigation; omit for a one-off."
    });
    let args_schema = json!({
        "type": "object",
        "required": ["question"],
        "properties": {
            "question": question_arg,
            "session_id": session_arg
        }
    });

    let result_schema = json!({
        "type": "object",
        "properties": {
            "session_id": { "type": "string" },
            "question": { "type": "string" },
            "answer": { "type": "string" },
            "turn_index": { "type": "integer" },
            "turn_count": { "type": "integer" }
        }
    });

    ToolSchema {
        name: name.to_owned(),
        description: description.to_owned(),
        args_schema,
        result_schema,
    }
}

/// A ready-to-serve MCP service exposing one [`DciAgent`] as a stateful tool.
///
/// Build it with one of the `new*` / `with_*` constructors. Optionally
/// register more tools through [`DciMcpService::registry`], then call
/// [`DciMcpService::serve_stdio`].
pub struct DciMcpService {
    /// Registry holding the DCI tool (plus any tools callers add later).
    registry: ToolRegistry,
    /// Session store handle; a clone of it is handed to the registered tool's
    /// delegate at construction, and this copy is kept for inspection.
    sessions: SessionStore,
    /// Name under which the DCI tool was registered.
    tool_name: String,
}

impl DciMcpService {
    /// Build a service that exposes `agent` under [`DEFAULT_TOOL_NAME`] with
    /// the default [`SessionConfig`].
    pub fn new<M: CompletionModel + 'static>(agent: DciAgent<M>) -> Self {
        Self::with_config(agent, DEFAULT_TOOL_NAME, SessionConfig::default())
    }

    /// Build a service that exposes `agent` under [`DEFAULT_TOOL_NAME`], using
    /// the given session eviction bounds (see [`SessionConfig`]).
    pub fn new_with_config<M: CompletionModel + 'static>(
        agent: DciAgent<M>,
        session_config: SessionConfig,
    ) -> Self {
        Self::with_config(agent, DEFAULT_TOOL_NAME, session_config)
    }

    /// Build a service that exposes `agent` under `tool_name` with the default
    /// [`SessionConfig`].
    pub fn with_name<M: CompletionModel + 'static>(
        agent: DciAgent<M>,
        tool_name: impl Into<String>,
    ) -> Self {
        Self::with_config(agent, tool_name, SessionConfig::default())
    }

    /// Build a service that exposes `agent` under `tool_name`, using the given
    /// session eviction bounds (see [`SessionConfig`]).
    ///
    /// Every other constructor funnels into this one. It creates the session
    /// store, hands a clone of it to a [`DciDelegate`], wraps that delegate as
    /// a [`DelegateTool`], and registers the tool in a fresh [`ToolRegistry`].
    pub fn with_config<M: CompletionModel + 'static>(
        agent: DciAgent<M>,
        tool_name: impl Into<String>,
        session_config: SessionConfig,
    ) -> Self {
        let name: String = tool_name.into();
        let sessions = SessionStore::with_config(session_config);

        // The delegate shares the store with the service via a clone.
        let executor: Arc<dyn DelegateExecutor> =
            Arc::new(DciDelegate::with_store(Arc::new(agent), sessions.clone()));
        let dci_tool = DelegateTool::with_schema(tool_schema(&name), executor);

        let registry = ToolRegistry::new();
        registry.register(Arc::new(dci_tool));

        Self {
            registry,
            sessions,
            tool_name: name,
        }
    }

    /// Borrow the tool registry, e.g. to register additional tools before
    /// serving.
    pub fn registry(&self) -> &ToolRegistry {
        &self.registry
    }

    /// Borrow the shared session store, for inspecting investigation history.
    pub fn sessions(&self) -> &SessionStore {
        &self.sessions
    }

    /// The name the DCI tool is registered under.
    pub fn tool_name(&self) -> &str {
        &self.tool_name
    }

    /// Serve the registry over stdio (JSON-RPC on stdin/stdout) and return
    /// once `rig_mcp::serve_stdio` returns — per its contract, when the peer
    /// disconnects.
    pub async fn serve_stdio(self) -> Result<(), KernelError> {
        rig_mcp::serve_stdio(self.registry).await
    }
}

#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::indexing_slicing,
        clippy::panic
    )]
    use super::*;
    use std::time::SystemTime;

    /// A history entry stamped "now"; only question/answer matter here.
    fn turn(q: &str, a: &str) -> InvestigationTurn {
        let at = SystemTime::now();
        InvestigationTurn {
            question: q.to_owned(),
            answer: a.to_owned(),
            at,
        }
    }

    /// Twenty turns `q0/a0` … `q19/a19`, oldest first — longer than the window.
    fn long_history() -> Vec<InvestigationTurn> {
        (0..20)
            .map(|i| turn(&format!("q{i}"), &format!("a{i}")))
            .collect()
    }

    #[test]
    fn parse_args_requires_question() {
        // Missing question, and a whitespace-only question, are both rejected.
        for bad in [json!({ "session_id": "s" }), json!({ "question": "   " })] {
            let err = parse_invoke_args(&bad).unwrap_err();
            assert!(matches!(err, KernelError::InvalidArgument(_)));
        }
    }

    #[test]
    fn parse_args_defaults_session() {
        let defaulted = parse_invoke_args(&json!({ "question": "who logged in?" })).unwrap();
        assert_eq!(defaulted.question, "who logged in?");
        assert_eq!(defaulted.session_id, DEFAULT_SESSION_ID);

        let explicit =
            parse_invoke_args(&json!({ "question": "q", "session_id": "case-7" })).unwrap();
        assert_eq!(explicit.session_id, "case-7");
    }

    #[test]
    fn first_turn_prompt_is_verbatim() {
        assert_eq!(
            build_continuation_prompt(&[], "find the brute force source"),
            "find the brute force source"
        );
    }

    #[test]
    fn continuation_prompt_includes_prior_turns() {
        let history = [turn("who logged in?", "alice from 10.0.0.5")];
        let prompt = build_continuation_prompt(&history, "and when?");
        assert!(prompt.contains("continuing an existing investigation"));
        assert!(prompt.contains("alice from 10.0.0.5"));
        assert!(prompt.ends_with("and when?"));
    }

    #[test]
    fn continuation_prompt_caps_history() {
        let prompt = build_continuation_prompt(&long_history(), "next");
        // Only the newest MAX_HISTORY_TURNS_IN_PROMPT turns survive.
        assert!(prompt.contains("a19"));
        assert!(!prompt.contains("a0\n"));
        assert_eq!(
            prompt.matches("[recent exchange ").count(),
            MAX_HISTORY_TURNS_IN_PROMPT
        );
    }

    #[test]
    fn continuation_prompt_numbers_relative_to_window() {
        // Labels count within the rendered window (1/N..N/N), never absolute
        // turn numbers, even when the caller hands over the full history.
        let prompt = build_continuation_prompt(&long_history(), "next");
        assert!(prompt.contains("[recent exchange 1/6]"));
        assert!(prompt.contains("[recent exchange 6/6]"));
        assert!(!prompt.contains("turn 15"));
    }

    #[test]
    fn schema_advertises_required_question() {
        let schema = tool_schema(DEFAULT_TOOL_NAME);
        assert_eq!(schema.name, DEFAULT_TOOL_NAME);
        let first_required = schema
            .args_schema
            .get("required")
            .and_then(Value::as_array)
            .and_then(|required| required.first())
            .and_then(Value::as_str)
            .unwrap();
        assert_eq!(first_required, "question");
    }
}