mod decompose;
mod extract;
#[cfg(test)]
mod tests;
#[cfg(test)]
mod tests_extract_e2e;
#[cfg(test)]
mod tests_extraction_e2e;
#[cfg(test)]
mod tests_wiremock;
mod verbs;
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
use std::sync::Arc;
use dashmap::DashMap;
use tokio_util::sync::CancellationToken;
use tracing::{debug, instrument};
use crate::ast::output::{OutputFormat, OutputPolicy, SchemaRef};
use crate::ast::{McpConfigInline, TaskAction};
use crate::binding::ResolvedBindings;
use crate::error::NikaError;
use crate::event::EventLog;
use crate::mcp::{McpClient, McpClientPool};
use crate::media::CasStore;
use crate::provider::rig::RigProvider;
use crate::runtime::boot::PolicyConfig;
use crate::runtime::builtin::media::context::MediaToolContext;
use crate::runtime::policy::PolicyEnforcer;
use crate::runtime::BuiltinToolRouter;
use crate::runtime::SkillInjector;
use crate::store::RunContext;
use crate::tools::{PermissionMode, ToolContext};
use crate::util::{CONNECT_TIMEOUT, FETCH_TIMEOUT, REDIRECT_LIMIT};
/// Executes workflow task actions (`infer`, `exec`, `fetch`, `invoke`, `agent`).
///
/// The type derives `Clone`; fields wrapped in `Arc` are shared between clones,
/// while `skills_map` and `workflow_base_dir` are copied.
#[derive(Clone)]
pub struct TaskExecutor {
/// HTTP client built in `with_policy` with request/connect timeouts, a redirect
/// limit and a `nika/<version>` user agent.
http_client: reqwest::Client,
/// Cache of provider instances keyed by provider name, filled lazily by
/// `get_rig_provider`.
rig_provider_cache: Arc<DashMap<String, RigProvider>>,
/// Pool of MCP clients, created from the inline MCP configurations passed to
/// the constructor.
mcp_pool: McpClientPool,
/// Provider name supplied at construction; exposed through `default_provider()`.
default_provider: Arc<str>,
/// Optional model name supplied at construction.
default_model: Option<Arc<str>>,
/// Event log supplied at construction (a clone is also handed to the MCP pool).
event_log: EventLog,
/// Router for built-in tools, constructed with `BuiltinToolRouter::with_all_tools`.
builtin_router: Arc<BuiltinToolRouter>,
/// Policy enforcer behind a read/write lock, built from the optional `PolicyConfig`.
policy_enforcer: Arc<parking_lot::RwLock<PolicyEnforcer>>,
/// Cancellation token; a fresh token by default, replaceable via `with_cancel_token`.
cancel_token: CancellationToken,
/// `CasStore` created with `CasStore::workspace_default` for the working directory.
cas: Arc<CasStore>,
/// Skill injector created with `SkillInjector::new()`.
skill_injector: Arc<SkillInjector>,
/// Skills table set via `with_skills`; empty by default.
/// NOTE(review): presumably maps skill names to paths or contents — confirm against the caller.
skills_map: std::collections::HashMap<String, String>,
/// Base directory for the workflow; defaults to the working directory and is
/// overridden by `with_skills`.
workflow_base_dir: std::path::PathBuf,
}
impl TaskExecutor {
pub fn new(
provider: &str,
model: Option<&str>,
mcp_configs: Option<FxHashMap<String, McpConfigInline>>,
event_log: EventLog,
) -> Self {
Self::with_policy(provider, model, mcp_configs, event_log, None)
}
/// Creates an executor, optionally with an explicit policy configuration.
///
/// # Parameters
/// - `provider`: name stored as the default provider.
/// - `model`: optional default model name.
/// - `mcp_configs`: inline MCP configurations handed to the client pool;
///   `None` is treated as an empty map.
/// - `event_log`: event log stored on the executor; a clone goes to the MCP pool.
/// - `policy_config`: policy settings; `None` uses `PolicyConfig::default()`.
///
/// Construction is infallible: a failed HTTP client build falls back to
/// `reqwest::Client::new()`, and a failed current-directory lookup falls back
/// to the platform temporary directory. Both fallbacks are logged.
pub fn with_policy(
    provider: &str,
    model: Option<&str>,
    mcp_configs: Option<FxHashMap<String, McpConfigInline>>,
    event_log: EventLog,
    policy_config: Option<PolicyConfig>,
) -> Self {
    // Degrade to a default client rather than making the constructor fallible.
    let http_client = reqwest::Client::builder()
        .timeout(FETCH_TIMEOUT)
        .connect_timeout(CONNECT_TIMEOUT)
        .redirect(reqwest::redirect::Policy::limited(REDIRECT_LIMIT))
        .user_agent(format!("nika/{}", env!("CARGO_PKG_VERSION")))
        .build()
        .unwrap_or_else(|e| {
            tracing::error!("HTTP client build failed: {e}. Using default client.");
            reqwest::Client::new()
        });
    let policy_enforcer = PolicyEnforcer::new(policy_config.unwrap_or_default());

    // Use the platform temp directory as the fallback instead of a hard-coded
    // `/tmp`, which does not exist on every platform, and surface the
    // underlying error so the fallback is diagnosable.
    let working_dir = std::env::current_dir().unwrap_or_else(|e| {
        let fallback = std::env::temp_dir();
        tracing::warn!(
            "Failed to get current directory ({e}), using {}",
            fallback.display()
        );
        fallback
    });

    let tool_ctx = Arc::new(ToolContext::new(
        working_dir.clone(),
        PermissionMode::YoloMode,
    ));
    // NOTE(review): the media context and the `cas` field each get their own
    // `CasStore::workspace_default` instance for the same directory — confirm
    // that two independent instances are intended rather than one shared store.
    let media_ctx = Arc::new(MediaToolContext::new(CasStore::workspace_default(
        &working_dir,
    )));
    let cas = Arc::new(CasStore::workspace_default(&working_dir));

    Self {
        http_client,
        rig_provider_cache: Arc::new(DashMap::new()),
        mcp_pool: McpClientPool::with_configs(
            event_log.clone(),
            mcp_configs.unwrap_or_default(),
        ),
        default_provider: provider.into(),
        default_model: model.map(Into::into),
        event_log,
        builtin_router: Arc::new(BuiltinToolRouter::with_all_tools(tool_ctx, media_ctx)),
        policy_enforcer: Arc::new(RwLock::new(policy_enforcer)),
        cancel_token: CancellationToken::new(),
        cas,
        skill_injector: Arc::new(SkillInjector::new()),
        skills_map: std::collections::HashMap::new(),
        // Until `with_skills` overrides it, the workflow base is the working directory.
        workflow_base_dir: working_dir,
    }
}
/// Builder-style setter: swaps in `token` as the executor's cancellation token
/// and hands the executor back.
pub fn with_cancel_token(self, token: CancellationToken) -> Self {
    let mut executor = self;
    executor.cancel_token = token;
    executor
}
/// Builder-style setter: installs the skills table and the workflow base
/// directory, replacing whatever the executor held before.
pub fn with_skills(
    self,
    skills_map: std::collections::HashMap<String, String>,
    base_dir: std::path::PathBuf,
) -> Self {
    let mut executor = self;
    executor.skills_map = skills_map;
    executor.workflow_base_dir = base_dir;
    executor
}
/// Test-only helper: registers a mock MCP client under `name` in the pool.
#[cfg(test)]
pub fn inject_mock_mcp_client(&self, name: &str) {
    let mock_client = Arc::new(McpClient::mock(name));
    self.mcp_pool.inject_mock(name, mock_client);
}
/// Builds the prompt suffix instructing the model to answer with JSON.
///
/// Returns `None` when no output policy is given, when the policy's format is
/// not [`OutputFormat::Json`], or when the policy carries no schema.
///
/// Otherwise:
/// - for an inline schema, the pretty-printed schema is embedded in the
///   instruction so the model can conform to it;
/// - for a file-based schema, a generic "valid JSON" instruction is returned,
///   since the file reference is not read by this function.
pub(super) fn build_json_schema_instruction(
    output_policy: Option<&OutputPolicy>,
) -> Option<String> {
    let policy = output_policy?;
    if policy.format != OutputFormat::Json {
        return None;
    }
    let schema_ref = policy.schema.as_ref()?;
    let inline_schema = match schema_ref {
        // Borrow the schema: serialization only needs a reference, so the
        // deep clone of the schema value is unnecessary.
        SchemaRef::Inline(v) => v,
        SchemaRef::File(_) => {
            return Some(
                "\n\n---\n\
                 CRITICAL OUTPUT REQUIREMENT:\n\
                 Your response MUST be valid JSON.\n\n\
                 Rules:\n\
                 - Output ONLY the JSON object, no additional text\n\
                 - Do NOT wrap in markdown code blocks (no ```json)\n\
                 - Ensure all JSON is properly formatted and valid"
                    .to_string(),
            );
        }
    };
    // A serialization failure degrades to an empty schema string rather than
    // aborting prompt construction.
    let schema_str = serde_json::to_string_pretty(inline_schema).unwrap_or_default();
    Some(format!(
        "\n\n---\n\
         CRITICAL OUTPUT REQUIREMENT:\n\
         Your response MUST be valid JSON that conforms to this schema:\n\n\
         ```json\n{}\n```\n\n\
         Rules:\n\
         - Output ONLY the JSON object, no additional text before or after\n\
         - Do NOT wrap your response in markdown code blocks (no ```json)\n\
         - All required fields must be present\n\
         - Field types must match the schema exactly",
        schema_str
    ))
}
/// Runs a single task action and returns its output as a string.
///
/// The call is dispatched on the [`TaskAction`] variant to the matching
/// `run_*` method (`run_infer`, `run_exec`, `run_fetch`, `run_invoke`,
/// `run_agent`). Only the `infer` and `agent` handlers receive
/// `output_policy`; the other verbs do not take it.
///
/// # Errors
/// Returns whatever [`NikaError`] the selected handler produces.
#[instrument(skip(self, bindings, datastore, output_policy), fields(action_type = %action_type(action)))]
pub async fn execute(
    &self,
    task_id: &Arc<str>,
    action: &TaskAction,
    bindings: &ResolvedBindings,
    datastore: &RunContext,
    output_policy: Option<&OutputPolicy>,
) -> Result<String, NikaError> {
    debug!("Running task action");
    match action {
        TaskAction::Infer { infer: spec } => {
            self.run_infer(task_id, spec, bindings, datastore, output_policy)
                .await
        }
        TaskAction::Exec { exec: spec } => {
            self.run_exec(task_id, spec, bindings, datastore).await
        }
        TaskAction::Fetch { fetch: spec } => {
            self.run_fetch(task_id, spec, bindings, datastore).await
        }
        TaskAction::Invoke { invoke: spec } => {
            self.run_invoke(task_id, spec, bindings, datastore).await
        }
        TaskAction::Agent { agent: spec } => {
            self.run_agent(task_id, spec, bindings, datastore, output_policy)
                .await
        }
    }
}
/// Returns the provider registered under `name`, creating and caching it on
/// first use.
///
/// # Errors
/// Propagates the error from `RigProvider::from_name` when the provider cannot
/// be created; nothing is cached in that case.
pub(super) fn get_rig_provider(&self, name: &str) -> Result<RigProvider, NikaError> {
    use dashmap::mapref::entry::Entry;

    // Fast path: a borrowed-key lookup avoids allocating a `String` key and
    // taking the write-side entry lock on every cache hit. The guard is
    // released at the end of this `if let` statement, before `entry` is called.
    if let Some(cached) = self.rig_provider_cache.get(name) {
        return Ok(cached.value().clone());
    }

    // Slow path: the entry API re-checks under the entry lock, so a provider
    // inserted concurrently since the lookup above is reused, not rebuilt.
    match self.rig_provider_cache.entry(name.to_owned()) {
        Entry::Occupied(slot) => Ok(slot.get().clone()),
        Entry::Vacant(slot) => {
            let provider = RigProvider::from_name(name)?;
            slot.insert(provider.clone());
            Ok(provider)
        }
    }
}
/// Returns the default provider name stored on this executor.
pub fn default_provider(&self) -> &str {
    self.default_provider.as_ref()
}
/// Obtains the MCP client for `name` from the pool via `get_or_connect`.
///
/// # Errors
/// The pool's error is converted into a [`NikaError`] and returned.
pub(super) async fn get_mcp_client(&self, name: &str) -> Result<Arc<McpClient>, NikaError> {
    match self.mcp_pool.get_or_connect(name).await {
        Ok(client) => Ok(client),
        Err(pool_err) => Err(pool_err.into()),
    }
}
}
/// Maps a [`TaskAction`] variant to its lowercase verb name.
///
/// The result is used as the `action_type` field of the tracing span created
/// for `TaskExecutor::execute`.
pub(super) fn action_type(action: &TaskAction) -> &'static str {
    match *action {
        TaskAction::Infer { .. } => "infer",
        TaskAction::Exec { .. } => "exec",
        TaskAction::Fetch { .. } => "fetch",
        TaskAction::Invoke { .. } => "invoke",
        TaskAction::Agent { .. } => "agent",
    }
}