nika_engine/runtime/executor/
mod.rs1mod agent;
17mod decompose;
18mod exec;
19mod extract;
20mod fetch;
21mod infer;
22mod invoke;
23#[cfg(test)]
24mod tests;
25#[cfg(test)]
26mod tests_extract_e2e;
27#[cfg(test)]
28mod tests_extraction_e2e;
29#[cfg(test)]
30mod tests_wiremock;
31mod verbs;
32
33use parking_lot::RwLock;
34use rustc_hash::FxHashMap;
35use std::sync::Arc;
36
37use dashmap::DashMap;
38use tokio_util::sync::CancellationToken;
39use tracing::{debug, instrument};
40
41use crate::ast::output::{OutputFormat, OutputPolicy, SchemaRef};
42use crate::ast::{McpConfigInline, TaskAction};
43use crate::binding::ResolvedBindings;
44use crate::error::NikaError;
45use crate::event::{EventKind, EventLog};
46use crate::mcp::{McpClient, McpClientPool};
47use crate::media::CasStore;
48use crate::provider::rig::RigProvider;
49use crate::runtime::boot::PolicyConfig;
50use crate::runtime::builtin::media::context::MediaToolContext;
51use crate::runtime::policy::PolicyEnforcer;
52use crate::runtime::BuiltinToolRouter;
53use crate::runtime::SkillInjector;
54use crate::store::RunContext;
55use crate::tools::{PermissionMode, ToolContext};
56use crate::util::{CONNECT_TIMEOUT, FETCH_TIMEOUT, REDIRECT_LIMIT};
57
58#[derive(Clone)]
60pub struct TaskExecutor {
61 http_client: reqwest::Client,
63 rig_provider_cache: Arc<DashMap<String, RigProvider>>,
65 mcp_pool: McpClientPool,
71 default_provider: Arc<str>,
73 default_model: Option<Arc<str>>,
75 event_log: EventLog,
77 builtin_router: Arc<BuiltinToolRouter>,
79 policy_enforcer: Arc<parking_lot::RwLock<PolicyEnforcer>>,
81 cancel_token: CancellationToken,
86 cas: Arc<CasStore>,
88 tool_ctx: Arc<ToolContext>,
90 skill_injector: Arc<SkillInjector>,
92 skills_map: std::collections::HashMap<String, String>,
94 workflow_base_dir: std::path::PathBuf,
96}
97
98impl TaskExecutor {
99 pub fn new(
101 provider: &str,
102 model: Option<&str>,
103 mcp_configs: Option<FxHashMap<String, McpConfigInline>>,
104 event_log: EventLog,
105 ) -> Result<Self, NikaError> {
106 Self::with_policy(provider, model, mcp_configs, event_log, None, None)
107 }
108
109 pub fn with_policy(
113 provider: &str,
114 model: Option<&str>,
115 mcp_configs: Option<FxHashMap<String, McpConfigInline>>,
116 event_log: EventLog,
117 policy_config: Option<PolicyConfig>,
118 permission_mode: Option<PermissionMode>,
119 ) -> Result<Self, NikaError> {
120 let ssrf_redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
126 use crate::runtime::policy::is_ssrf_blocked;
127
128 if attempt.previous().len() >= REDIRECT_LIMIT {
129 attempt.stop()
130 } else {
131 let blocked = attempt.url().host_str().and_then(|host| {
132 let h = host.to_lowercase();
133 let h_normalized = h.trim_start_matches('[').trim_end_matches(']');
134 if is_ssrf_blocked(h_normalized) {
135 Some(h)
136 } else {
137 None
138 }
139 });
140 if let Some(host) = blocked {
141 attempt.error(std::io::Error::new(
142 std::io::ErrorKind::PermissionDenied,
143 format!("SSRF protection: redirect to '{}' blocked", host),
144 ))
145 } else {
146 attempt.follow()
147 }
148 }
149 });
150 let http_client = reqwest::Client::builder()
151 .timeout(FETCH_TIMEOUT)
152 .connect_timeout(CONNECT_TIMEOUT)
153 .redirect(ssrf_redirect_policy)
154 .user_agent(format!("nika/{}", env!("CARGO_PKG_VERSION")))
155 .build()
156 .expect("HTTP client build with default TLS is infallible");
157
158 let policy_enforcer = PolicyEnforcer::new(policy_config.unwrap_or_default());
159
160 let working_dir = std::env::current_dir().unwrap_or_else(|_| {
162 tracing::warn!("Failed to get current directory, using /tmp");
163 std::path::PathBuf::from("/tmp")
164 });
165 let perm = permission_mode.unwrap_or(PermissionMode::Plan);
166 tracing::debug!(?perm, "File tools using PermissionMode");
167 let tool_ctx = Arc::new(ToolContext::new(working_dir.clone(), perm));
168
169 let media_ctx = Arc::new(MediaToolContext::new(CasStore::workspace_default(
171 &working_dir,
172 ))?);
173 let cas = Arc::new(CasStore::workspace_default(&working_dir));
175
176 Ok(Self {
177 http_client,
178 rig_provider_cache: Arc::new(DashMap::new()),
179 mcp_pool: McpClientPool::with_configs(
180 event_log.clone(),
181 mcp_configs.unwrap_or_default(),
182 ),
183 default_provider: provider.into(),
184 default_model: model.map(Into::into),
185 event_log,
186 builtin_router: Arc::new(BuiltinToolRouter::with_all_tools(
187 tool_ctx.clone(),
188 media_ctx,
189 )),
190 policy_enforcer: Arc::new(RwLock::new(policy_enforcer)),
191 cancel_token: CancellationToken::new(),
192 cas,
193 tool_ctx,
194 skill_injector: Arc::new(SkillInjector::new()),
195 skills_map: std::collections::HashMap::new(),
196 workflow_base_dir: working_dir,
197 })
198 }
199
200 pub fn set_permission_mode(&self, mode: PermissionMode) {
202 self.tool_ctx.set_permission_mode(mode);
203 }
204
205 pub fn with_cancel_token(mut self, token: CancellationToken) -> Self {
210 self.cancel_token = token;
211 self
212 }
213
214 pub fn is_cancelled(&self) -> bool {
216 self.cancel_token.is_cancelled()
217 }
218
219 pub fn with_skills(
224 mut self,
225 skills_map: std::collections::HashMap<String, String>,
226 base_dir: std::path::PathBuf,
227 ) -> Self {
228 self.skills_map = skills_map;
229 self.workflow_base_dir = base_dir;
230 self
231 }
232
233 #[cfg(test)]
238 pub fn inject_mock_mcp_client(&self, name: &str) {
239 self.mcp_pool
240 .inject_mock(name, Arc::new(McpClient::mock(name)));
241 }
242
243 pub(super) fn build_json_schema_instruction(
253 output_policy: Option<&OutputPolicy>,
254 cached_example: Option<&serde_json::Value>,
255 ) -> Option<String> {
256 let policy = output_policy?;
257 if policy.format != OutputFormat::Json {
258 return None;
259 }
260
261 match policy.from_example.as_ref() {
263 Some(SchemaRef::Inline(ref example)) => {
264 return Self::format_example_instruction(example);
265 }
266 Some(SchemaRef::File(_)) => {
267 if let Some(example) = cached_example {
269 return Self::format_example_instruction(example);
270 }
271 return Some(
272 "\n\n---\n\
273 CRITICAL OUTPUT REQUIREMENT:\n\
274 Your response MUST be valid JSON.\n\n\
275 Rules:\n\
276 - Output ONLY the JSON object, no additional text\n\
277 - Do NOT wrap in markdown code blocks (no ```json)\n\
278 - Ensure all JSON is properly formatted and valid"
279 .to_string(),
280 );
281 }
282 None => {} }
284
285 let schema_ref = policy.schema.as_ref()?;
286 let schema_json = match schema_ref {
287 SchemaRef::Inline(v) => v.clone(),
288 SchemaRef::File(_) => {
289 return Some(
290 "\n\n---\n\
291 CRITICAL OUTPUT REQUIREMENT:\n\
292 Your response MUST be valid JSON.\n\n\
293 Rules:\n\
294 - Output ONLY the JSON object, no additional text\n\
295 - Do NOT wrap in markdown code blocks (no ```json)\n\
296 - Ensure all JSON is properly formatted and valid"
297 .to_string(),
298 );
299 }
300 };
301 let schema_str = serde_json::to_string_pretty(&schema_json).unwrap_or_default();
302 Some(format!(
303 "\n\n---\n\
304 CRITICAL OUTPUT REQUIREMENT:\n\
305 Your response MUST be valid JSON that conforms to this schema:\n\n\
306 ```json\n{}\n```\n\n\
307 Rules:\n\
308 - Output ONLY the JSON object, no additional text before or after\n\
309 - Do NOT wrap your response in markdown code blocks (no ```json)\n\
310 - All required fields must be present\n\
311 - Field types must match the schema exactly",
312 schema_str
313 ))
314 }
315
316 fn format_example_instruction(example: &serde_json::Value) -> Option<String> {
318 let example_str = match serde_json::to_string_pretty(example) {
319 Ok(s) => s,
320 Err(e) => {
321 tracing::warn!(
322 "Failed to serialize from_example for prompt injection: {}",
323 e
324 );
325 return None;
326 }
327 };
328 Some(format!(
329 "\n\n---\n\
330 CRITICAL OUTPUT REQUIREMENT:\n\
331 Your response MUST be valid JSON matching this exact structure:\n\n\
332 ```json\n{}\n```\n\n\
333 Rules:\n\
334 - Output ONLY the JSON object, no additional text\n\
335 - Do NOT wrap in markdown code blocks (no ```json)\n\
336 - All keys shown above must be present\n\
337 - Value types must match (strings, numbers, arrays, objects)",
338 example_str
339 ))
340 }
341
342 #[instrument(skip(self, bindings, datastore, output_policy), fields(action_type = %action_type(action)))]
347 pub async fn execute(
348 &self,
349 task_id: &Arc<str>,
350 action: &TaskAction,
351 bindings: &ResolvedBindings,
352 datastore: &RunContext,
353 output_policy: Option<&OutputPolicy>,
354 ) -> Result<String, NikaError> {
355 debug!("Running task action");
356 match action {
357 TaskAction::Infer { infer } => {
358 self.run_infer(task_id, infer, bindings, datastore, output_policy)
359 .await
360 }
361 TaskAction::Exec { exec: e } => self.run_exec(task_id, e, bindings, datastore).await,
362 TaskAction::Fetch { fetch } => {
363 self.run_fetch(task_id, fetch, bindings, datastore).await
364 }
365 TaskAction::Invoke { invoke } => {
366 self.run_invoke(task_id, invoke, bindings, datastore).await
367 }
368 TaskAction::Agent { agent } => {
369 self.run_agent(task_id, agent, bindings, datastore, output_policy)
370 .await
371 }
372 }
373 }
374
375 pub(super) fn get_rig_provider(&self, name: &str) -> Result<RigProvider, NikaError> {
380 use dashmap::mapref::entry::Entry;
381
382 let canonical = crate::core::find_provider(name)
385 .map(|p| p.id)
386 .unwrap_or(name);
387
388 match self.rig_provider_cache.entry(canonical.to_string()) {
389 Entry::Occupied(e) => Ok(e.get().clone()),
390 Entry::Vacant(e) => {
391 let provider = RigProvider::from_name(name)?;
392 e.insert(provider.clone());
393 self.event_log.emit(EventKind::ProviderInitialized {
395 provider: canonical.to_string(),
396 model: provider.default_model().to_string(),
397 cached: false,
398 });
399 Ok(provider)
400 }
401 }
402 }
403
404 pub fn default_provider(&self) -> &str {
406 &self.default_provider
407 }
408
409 pub(super) async fn get_mcp_client(&self, name: &str) -> Result<Arc<McpClient>, NikaError> {
417 self.mcp_pool.get_or_connect(name).await.map_err(Into::into)
418 }
419
420 pub async fn shutdown_mcp(&self) {
425 self.mcp_pool.shutdown_all().await;
426 }
427}
428
429pub(super) fn action_type(action: &TaskAction) -> &'static str {
431 match action {
432 TaskAction::Infer { .. } => "infer",
433 TaskAction::Exec { .. } => "exec",
434 TaskAction::Fetch { .. } => "fetch",
435 TaskAction::Invoke { .. } => "invoke",
436 TaskAction::Agent { .. } => "agent",
437 }
438}