Skip to main content

nika_engine/runtime/executor/
mod.rs

1//! Task Executor - individual task execution
2//!
3//! Handles execution of individual tasks: infer, exec, fetch, invoke, agent.
//! Uses DashMap (a concurrent, sharded-lock map) for provider caching.
5//!
6//! ## Module Organization
7//! - `mod.rs`: TaskExecutor struct, constructors, dispatch, shared helpers
8//! - `verbs.rs`: Shared helper functions (estimate_tokens, coerce_json_types, etc.)
9//! - `infer.rs`: `run_infer` + `run_infer_vision` + guardrails
10//! - `exec.rs`: `run_exec` (shell command execution)
11//! - `fetch.rs`: `run_fetch` (HTTP requests)
12//! - `invoke.rs`: `run_invoke` (MCP tool calls / resource reads)
13//! - `agent.rs`: `run_agent` (multi-turn agentic loops)
14//! - `decompose.rs`: Decompose expansion strategies (semantic, static, nested)
15
16mod agent;
17mod decompose;
18mod exec;
19mod extract;
20mod fetch;
21mod infer;
22mod invoke;
23#[cfg(test)]
24mod tests;
25#[cfg(test)]
26mod tests_extract_e2e;
27#[cfg(test)]
28mod tests_extraction_e2e;
29#[cfg(test)]
30mod tests_wiremock;
31mod verbs;
32
33use parking_lot::RwLock;
34use rustc_hash::FxHashMap;
35use std::sync::Arc;
36
37use dashmap::DashMap;
38use tokio_util::sync::CancellationToken;
39use tracing::{debug, instrument};
40
41use crate::ast::output::{OutputFormat, OutputPolicy, SchemaRef};
42use crate::ast::{McpConfigInline, TaskAction};
43use crate::binding::ResolvedBindings;
44use crate::error::NikaError;
45use crate::event::{EventKind, EventLog};
46use crate::mcp::{McpClient, McpClientPool};
47use crate::media::CasStore;
48use crate::provider::rig::RigProvider;
49use crate::runtime::boot::PolicyConfig;
50use crate::runtime::builtin::media::context::MediaToolContext;
51use crate::runtime::policy::PolicyEnforcer;
52use crate::runtime::BuiltinToolRouter;
53use crate::runtime::SkillInjector;
54use crate::store::RunContext;
55use crate::tools::{PermissionMode, ToolContext};
56use crate::util::{CONNECT_TIMEOUT, FETCH_TIMEOUT, REDIRECT_LIMIT};
57
58/// Task executor with cached providers, shared HTTP client, and event logging
59#[derive(Clone)]
60pub struct TaskExecutor {
61    /// Shared HTTP client (connection pooling)
62    http_client: reqwest::Client,
63    /// Cached rig-core providers
64    rig_provider_cache: Arc<DashMap<String, RigProvider>>,
65    /// Centralized MCP client pool
66    ///
67    /// Replaces the previous `mcp_client_cache` + `mcp_configs` pair.
68    /// Handles lazy initialization, per-server deduplication via DashMap + OnceCell,
69    /// and graceful shutdown. Shared across TaskExecutor, TUI App, and ChatAgent.
70    mcp_pool: McpClientPool,
71    /// Default provider name
72    default_provider: Arc<str>,
73    /// Default model
74    default_model: Option<Arc<str>>,
75    /// Event log for fine-grained audit trail
76    event_log: EventLog,
77    /// Router for builtin nika:* tools
78    builtin_router: Arc<BuiltinToolRouter>,
79    /// Policy enforcer for security checks
80    policy_enforcer: Arc<parking_lot::RwLock<PolicyEnforcer>>,
81    /// Cancellation token for aborting in-flight operations
82    ///
83    /// When cancelled, MCP invoke operations race against this token
84    /// so they can abort promptly instead of waiting for INVOKE_TASK_DEADLINE.
85    cancel_token: CancellationToken,
86    /// CAS store for reading media blobs (used by vision content resolution)
87    cas: Arc<CasStore>,
88    /// Tool context for setting permission mode after construction
89    tool_ctx: Arc<ToolContext>,
90    /// Shared SkillInjector for loading and caching skill files
91    skill_injector: Arc<SkillInjector>,
92    /// Workflow-level skills mapping (alias -> file path)
93    skills_map: std::collections::HashMap<String, String>,
94    /// Base directory for resolving relative skill paths
95    workflow_base_dir: std::path::PathBuf,
96}
97
98impl TaskExecutor {
99    /// Create a new executor with default provider, model, MCP configs, and event log
100    pub fn new(
101        provider: &str,
102        model: Option<&str>,
103        mcp_configs: Option<FxHashMap<String, McpConfigInline>>,
104        event_log: EventLog,
105    ) -> Result<Self, NikaError> {
106        Self::with_policy(provider, model, mcp_configs, event_log, None, None)
107    }
108
109    /// Create a new executor with explicit policy configuration.
110    ///
111    /// Returns an error if the media compute pool cannot be created.
112    pub fn with_policy(
113        provider: &str,
114        model: Option<&str>,
115        mcp_configs: Option<FxHashMap<String, McpConfigInline>>,
116        event_log: EventLog,
117        policy_config: Option<PolicyConfig>,
118        permission_mode: Option<PermissionMode>,
119    ) -> Result<Self, NikaError> {
120        // SAFETY: ClientBuilder::build() only fails with custom TLS or proxy config.
121        // We use defaults, so this is effectively infallible.
122        //
123        // Custom redirect policy: check each hop against SSRF blocklist to prevent
124        // SSRF bypass via HTTP redirect (e.g., external → 169.254.169.254).
125        let ssrf_redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
126            use crate::runtime::policy::is_ssrf_blocked;
127
128            if attempt.previous().len() >= REDIRECT_LIMIT {
129                attempt.stop()
130            } else {
131                let blocked = attempt.url().host_str().and_then(|host| {
132                    let h = host.to_lowercase();
133                    let h_normalized = h.trim_start_matches('[').trim_end_matches(']');
134                    if is_ssrf_blocked(h_normalized) {
135                        Some(h)
136                    } else {
137                        None
138                    }
139                });
140                if let Some(host) = blocked {
141                    attempt.error(std::io::Error::new(
142                        std::io::ErrorKind::PermissionDenied,
143                        format!("SSRF protection: redirect to '{}' blocked", host),
144                    ))
145                } else {
146                    attempt.follow()
147                }
148            }
149        });
150        let http_client = reqwest::Client::builder()
151            .timeout(FETCH_TIMEOUT)
152            .connect_timeout(CONNECT_TIMEOUT)
153            .redirect(ssrf_redirect_policy)
154            .user_agent(format!("nika/{}", env!("CARGO_PKG_VERSION")))
155            .build()
156            .expect("HTTP client build with default TLS is infallible");
157
158        let policy_enforcer = PolicyEnforcer::new(policy_config.unwrap_or_default());
159
160        // Create ToolContext for file tools
161        let working_dir = std::env::current_dir().unwrap_or_else(|_| {
162            tracing::warn!("Failed to get current directory, using /tmp");
163            std::path::PathBuf::from("/tmp")
164        });
165        let perm = permission_mode.unwrap_or(PermissionMode::Plan);
166        tracing::debug!(?perm, "File tools using PermissionMode");
167        let tool_ctx = Arc::new(ToolContext::new(working_dir.clone(), perm));
168
169        // Create media tool context with CAS store at workspace default
170        let media_ctx = Arc::new(MediaToolContext::new(CasStore::workspace_default(
171            &working_dir,
172        ))?);
173        // Separate CAS handle for vision content resolution (same directory)
174        let cas = Arc::new(CasStore::workspace_default(&working_dir));
175
176        Ok(Self {
177            http_client,
178            rig_provider_cache: Arc::new(DashMap::new()),
179            mcp_pool: McpClientPool::with_configs(
180                event_log.clone(),
181                mcp_configs.unwrap_or_default(),
182            ),
183            default_provider: provider.into(),
184            default_model: model.map(Into::into),
185            event_log,
186            builtin_router: Arc::new(BuiltinToolRouter::with_all_tools(
187                tool_ctx.clone(),
188                media_ctx,
189            )),
190            policy_enforcer: Arc::new(RwLock::new(policy_enforcer)),
191            cancel_token: CancellationToken::new(),
192            cas,
193            tool_ctx,
194            skill_injector: Arc::new(SkillInjector::new()),
195            skills_map: std::collections::HashMap::new(),
196            workflow_base_dir: working_dir,
197        })
198    }
199
200    /// Set the permission mode for file tools (nika:write, nika:edit, etc.)
201    pub fn set_permission_mode(&self, mode: PermissionMode) {
202        self.tool_ctx.set_permission_mode(mode);
203    }
204
205    /// Set a cancellation token for aborting in-flight operations.
206    ///
207    /// When the token is cancelled, MCP invoke operations will abort promptly
208    /// instead of waiting for the full INVOKE_TASK_DEADLINE timeout.
209    pub fn with_cancel_token(mut self, token: CancellationToken) -> Self {
210        self.cancel_token = token;
211        self
212    }
213
214    /// Check if the executor has been cancelled.
215    pub fn is_cancelled(&self) -> bool {
216        self.cancel_token.is_cancelled()
217    }
218
219    /// Set the workflow-level skills mapping for agent skill injection.
220    ///
221    /// When set, agents with `skills:` configured will have skill content
222    /// loaded and prepended to their system prompts via `SkillInjector`.
223    pub fn with_skills(
224        mut self,
225        skills_map: std::collections::HashMap<String, String>,
226        base_dir: std::path::PathBuf,
227    ) -> Self {
228        self.skills_map = skills_map;
229        self.workflow_base_dir = base_dir;
230        self
231    }
232
233    /// Inject a mock MCP client for testing
234    ///
235    /// This allows tests to use mock clients without relying on automatic fallback.
236    /// Call this after creating the executor but before executing invoke actions.
237    #[cfg(test)]
238    pub fn inject_mock_mcp_client(&self, name: &str) {
239        self.mcp_pool
240            .inject_mock(name, Arc::new(McpClient::mock(name)));
241    }
242
243    /// Build JSON schema instruction for LLM prompts
244    ///
245    /// When output policy requires JSON format with a schema, this generates
246    /// an instruction string to append to the prompt, telling the LLM to
247    /// output valid JSON conforming to the schema.
248    /// Build JSON schema instruction for LLM prompt injection.
249    ///
250    /// `cached_example` is used for file-based `from_example` — the caller pre-reads
251    /// the file asynchronously and passes the parsed value here for synchronous injection.
252    pub(super) fn build_json_schema_instruction(
253        output_policy: Option<&OutputPolicy>,
254        cached_example: Option<&serde_json::Value>,
255    ) -> Option<String> {
256        let policy = output_policy?;
257        if policy.format != OutputFormat::Json {
258            return None;
259        }
260
261        // from_example: inject example structure or generic instruction.
262        match policy.from_example.as_ref() {
263            Some(SchemaRef::Inline(ref example)) => {
264                return Self::format_example_instruction(example);
265            }
266            Some(SchemaRef::File(_)) => {
267                // File-based: use cached_example if pre-loaded, otherwise generic instruction.
268                if let Some(example) = cached_example {
269                    return Self::format_example_instruction(example);
270                }
271                return Some(
272                    "\n\n---\n\
273                     CRITICAL OUTPUT REQUIREMENT:\n\
274                     Your response MUST be valid JSON.\n\n\
275                     Rules:\n\
276                     - Output ONLY the JSON object, no additional text\n\
277                     - Do NOT wrap in markdown code blocks (no ```json)\n\
278                     - Ensure all JSON is properly formatted and valid"
279                        .to_string(),
280                );
281            }
282            None => {} // no from_example — fall through to schema-based injection below
283        }
284
285        let schema_ref = policy.schema.as_ref()?;
286        let schema_json = match schema_ref {
287            SchemaRef::Inline(v) => v.clone(),
288            SchemaRef::File(_) => {
289                return Some(
290                    "\n\n---\n\
291                     CRITICAL OUTPUT REQUIREMENT:\n\
292                     Your response MUST be valid JSON.\n\n\
293                     Rules:\n\
294                     - Output ONLY the JSON object, no additional text\n\
295                     - Do NOT wrap in markdown code blocks (no ```json)\n\
296                     - Ensure all JSON is properly formatted and valid"
297                        .to_string(),
298                );
299            }
300        };
301        let schema_str = serde_json::to_string_pretty(&schema_json).unwrap_or_default();
302        Some(format!(
303            "\n\n---\n\
304             CRITICAL OUTPUT REQUIREMENT:\n\
305             Your response MUST be valid JSON that conforms to this schema:\n\n\
306             ```json\n{}\n```\n\n\
307             Rules:\n\
308             - Output ONLY the JSON object, no additional text before or after\n\
309             - Do NOT wrap your response in markdown code blocks (no ```json)\n\
310             - All required fields must be present\n\
311             - Field types must match the schema exactly",
312            schema_str
313        ))
314    }
315
316    /// Format an example JSON value into a prompt injection instruction.
317    fn format_example_instruction(example: &serde_json::Value) -> Option<String> {
318        let example_str = match serde_json::to_string_pretty(example) {
319            Ok(s) => s,
320            Err(e) => {
321                tracing::warn!(
322                    "Failed to serialize from_example for prompt injection: {}",
323                    e
324                );
325                return None;
326            }
327        };
328        Some(format!(
329            "\n\n---\n\
330             CRITICAL OUTPUT REQUIREMENT:\n\
331             Your response MUST be valid JSON matching this exact structure:\n\n\
332             ```json\n{}\n```\n\n\
333             Rules:\n\
334             - Output ONLY the JSON object, no additional text\n\
335             - Do NOT wrap in markdown code blocks (no ```json)\n\
336             - All keys shown above must be present\n\
337             - Value types must match (strings, numbers, arrays, objects)",
338            example_str
339        ))
340    }
341
342    /// Run a task action with the given bindings
343    ///
344    /// The datastore is required for resolving lazy bindings during template substitution.
345    /// The output_policy is used to inject JSON schema instructions into prompts for infer/agent.
346    #[instrument(skip(self, bindings, datastore, output_policy), fields(action_type = %action_type(action)))]
347    pub async fn execute(
348        &self,
349        task_id: &Arc<str>,
350        action: &TaskAction,
351        bindings: &ResolvedBindings,
352        datastore: &RunContext,
353        output_policy: Option<&OutputPolicy>,
354    ) -> Result<String, NikaError> {
355        debug!("Running task action");
356        match action {
357            TaskAction::Infer { infer } => {
358                self.run_infer(task_id, infer, bindings, datastore, output_policy)
359                    .await
360            }
361            TaskAction::Exec { exec: e } => self.run_exec(task_id, e, bindings, datastore).await,
362            TaskAction::Fetch { fetch } => {
363                self.run_fetch(task_id, fetch, bindings, datastore).await
364            }
365            TaskAction::Invoke { invoke } => {
366                self.run_invoke(task_id, invoke, bindings, datastore).await
367            }
368            TaskAction::Agent { agent } => {
369                self.run_agent(task_id, agent, bindings, datastore, output_policy)
370                    .await
371            }
372        }
373    }
374
375    /// Get or create a cached rig-core provider.
376    ///
377    /// Resolves provider names and aliases via [`RigProvider::from_name()`],
378    /// which uses `core::find_provider()` as the single source of truth.
379    pub(super) fn get_rig_provider(&self, name: &str) -> Result<RigProvider, NikaError> {
380        use dashmap::mapref::entry::Entry;
381
382        // Normalize provider name so aliases ("claude") and canonical ("anthropic")
383        // share the same cache entry, avoiding double-instantiation.
384        let canonical = crate::core::find_provider(name)
385            .map(|p| p.id)
386            .unwrap_or(name);
387
388        match self.rig_provider_cache.entry(canonical.to_string()) {
389            Entry::Occupied(e) => Ok(e.get().clone()),
390            Entry::Vacant(e) => {
391                let provider = RigProvider::from_name(name)?;
392                e.insert(provider.clone());
393                // EMIT: ProviderInitialized (cache miss — first use)
394                self.event_log.emit(EventKind::ProviderInitialized {
395                    provider: canonical.to_string(),
396                    model: provider.default_model().to_string(),
397                    cached: false,
398                });
399                Ok(provider)
400            }
401        }
402    }
403
404    /// Get the default provider name.
405    pub fn default_provider(&self) -> &str {
406        &self.default_provider
407    }
408
409    /// Get or create an MCP client for a named server
410    ///
411    /// Uses OnceCell per server to ensure thread-safe initialization.
412    /// Even with concurrent for_each iterations, only one client is created per server.
413    ///
414    /// Delegates to [`McpClientPool::get_or_connect`] which handles lazy initialization,
415    /// per-server deduplication via DashMap + OnceCell, and event logging.
416    pub(super) async fn get_mcp_client(&self, name: &str) -> Result<Arc<McpClient>, NikaError> {
417        self.mcp_pool.get_or_connect(name).await.map_err(Into::into)
418    }
419
420    /// Gracefully shut down all MCP server connections.
421    ///
422    /// Delegates to [`McpClientPool::shutdown_all`] which terminates server
423    /// processes and marks the pool as shut down. Idempotent.
424    pub async fn shutdown_mcp(&self) {
425        self.mcp_pool.shutdown_all().await;
426    }
427}
428
429/// Get action type as string for tracing
430pub(super) fn action_type(action: &TaskAction) -> &'static str {
431    match action {
432        TaskAction::Infer { .. } => "infer",
433        TaskAction::Exec { .. } => "exec",
434        TaskAction::Fetch { .. } => "fetch",
435        TaskAction::Invoke { .. } => "invoke",
436        TaskAction::Agent { .. } => "agent",
437    }
438}