use crate::agent::{Agent, AgentOptions};
use crate::hub::get_hub;
use crate::run::literals::Literals;
use crate::run::run_input::{run_agent_input, RunAgentInputResponse};
use crate::run::{DirContext, RunBaseOptions, Runtime};
use crate::script::{BeforeAllResponse, DevaiCustom, FromValue};
use crate::{Error, Result};
use serde::Serialize;
use serde_json::Value;
use simple_fs::SPath;
use tokio::task::JoinSet;
use value_ext::JsonValueExt;
/// Number of inputs run concurrently when the agent options do not set `input_concurrency`.
const DEFAULT_CONCURRENCY: usize = 1;

/// Result of `run_command_agent`.
///
/// Both fields are `None` when the before-all script asks to skip the run.
#[derive(Debug, Serialize, Default)]
pub struct RunCommandResponse {
	/// Per-input outputs, ordered by input index.
	/// Only captured when the agent has an after-all script or the caller asked for
	/// output values; otherwise `None`.
	pub outputs: Option<Vec<Value>>,
	/// Value returned by the agent's after-all script, if the agent has one.
	pub after_all: Option<Value>,
}
/// Returns the path to show the user for `file_path`.
///
/// Paths whose text contains `.devai-base` are returned unchanged. All other paths are
/// made relative to the workspace directory of `dir_context`.
fn get_display_path(file_path: &str, dir_context: &DirContext) -> Result<SPath> {
	let path = SPath::new(file_path)?;

	// Base-directory paths are displayed as-is, not relative to the workspace.
	if path.to_str().contains(".devai-base") {
		return Ok(path);
	}

	Ok(path.diff(dir_context.workspace_dir())?)
}
pub async fn run_command_agent(
runtime: &Runtime,
agent: Agent,
inputs: Option<Vec<Value>>,
run_base_options: &RunBaseOptions,
return_output_values: bool,
) -> Result<RunCommandResponse> {
let hub = get_hub();
let concurrency = agent.options().input_concurrency().unwrap_or(DEFAULT_CONCURRENCY);
let literals = Literals::from_dir_context_and_agent_path(runtime.dir_context(), &agent)?;
let BeforeAllResponse {
inputs,
before_all,
options: options_to_merge,
} = if let Some(before_all_script) = agent.before_all_script() {
let lua_engine = runtime.new_lua_engine()?;
let lua_scope = lua_engine.create_table()?;
let lua_inputs = inputs.clone().map(Value::Array).unwrap_or_default();
lua_scope.set("inputs", lua_engine.serde_to_lua_value(lua_inputs)?)?;
lua_scope.set("CTX", literals.to_lua(&lua_engine)?)?;
lua_scope.set("options", agent.options_as_ref())?;
let lua_value = lua_engine.eval(before_all_script, Some(lua_scope), Some(&[agent.file_dir()?.to_str()]))?;
let before_all_res = serde_json::to_value(lua_value)?;
match DevaiCustom::from_value(before_all_res)? {
FromValue::DevaiCustom(DevaiCustom::Skip { reason }) => {
let reason_msg = reason.map(|reason| format!(" (Reason: {reason})")).unwrap_or_default();
hub.publish(format!("-! DevAI Skip inputs at Before All section{reason_msg}"))
.await;
return Ok(RunCommandResponse::default());
}
FromValue::DevaiCustom(DevaiCustom::BeforeAllResponse(BeforeAllResponse {
inputs: inputs_ov,
before_all,
options,
})) => BeforeAllResponse {
inputs: inputs_ov.or(inputs),
before_all,
options,
},
FromValue::OriginalValue(value) => BeforeAllResponse {
inputs,
before_all: Some(value),
options: None,
},
}
} else {
BeforeAllResponse {
inputs,
before_all: None,
options: None,
}
};
let inputs = inputs.unwrap_or_else(|| vec![Value::Null]);
let before_all = before_all.unwrap_or_default();
let agent: Agent = match options_to_merge {
Some(options_to_merge) => {
let options_to_merge: AgentOptions = serde_json::from_value(options_to_merge)?;
let options_ov = agent.options_as_ref().merge_new(options_to_merge)?;
agent.new_merge(options_ov)?
}
None => agent,
};
let genai_info = get_genai_info(&agent);
let agent_path = match get_display_path(agent.file_path(), runtime.dir_context()) {
Ok(path) => path.to_string(),
Err(_) => agent.file_path().to_string(),
};
let model_str: &str = agent.model();
let model_resolved_str: &str = agent.model_resolved();
let model_name_message = if model_str != model_resolved_str {
format!("{model_str} ({model_resolved_str})")
} else {
model_resolved_str.to_string()
};
hub.publish(format!(
"\nRunning agent command: {}\n from: {}\n with model: {}{genai_info}",
agent.name(),
agent_path,
model_name_message
))
.await;
let mut captured_outputs: Option<Vec<(usize, Value)>> =
if agent.after_all_script().is_some() || return_output_values {
Some(Vec::new())
} else {
None
};
let mut join_set = JoinSet::new();
let mut in_progress = 0;
for (input_idx, input) in inputs.clone().into_iter().enumerate() {
let runtime_clone = runtime.clone();
let agent_clone = agent.clone();
let before_all_clone = before_all.clone();
let literals = literals.clone();
let base_run_config_clone = run_base_options.clone();
join_set.spawn(async move {
let run_input_response = run_command_agent_input(
input_idx,
&runtime_clone,
&agent_clone,
before_all_clone,
input,
&literals,
&base_run_config_clone,
)
.await?;
let run_input_value = run_input_response.map(|v| v.into_value()).unwrap_or_default();
let output = match DevaiCustom::from_value(run_input_value)? {
FromValue::DevaiCustom(DevaiCustom::Skip { reason }) => {
let reason_msg = reason.map(|reason| format!(" (Reason: {reason})")).unwrap_or_default();
hub.publish(format!("-! DevAI Skip input at Output stage{reason_msg}")).await;
Value::Null
}
FromValue::DevaiCustom(other) => {
return Err(Error::custom(format!(
"devai custom '{}' not supported at the Output stage",
other.as_ref()
)))
}
FromValue::OriginalValue(value) => value,
};
Ok((input_idx, output))
});
in_progress += 1;
if in_progress >= concurrency {
if let Some(res) = join_set.join_next().await {
in_progress -= 1;
match res {
Ok(Ok((input_idx, output))) => {
if let Some(outputs_vec) = &mut captured_outputs {
outputs_vec.push((input_idx, output));
}
}
Ok(Err(e)) => return Err(e),
Err(e) => return Err(Error::custom(format!("Error while running input. Cause {e}"))),
}
}
}
}
while in_progress > 0 {
if let Some(res) = join_set.join_next().await {
in_progress -= 1;
match res {
Ok(Ok((input_idx, output))) => {
if let Some(outputs_vec) = &mut captured_outputs {
outputs_vec.push((input_idx, output));
}
}
Ok(Err(e)) => return Err(e),
Err(e) => return Err(Error::custom(format!("Error while remaining input. Cause {e}"))),
}
}
}
let outputs = if let Some(mut captured_outputs) = captured_outputs {
captured_outputs.sort_by_key(|(idx, _)| *idx);
Some(captured_outputs.into_iter().map(|(_, v)| v).collect::<Vec<_>>())
} else {
None
};
let after_all = if let Some(after_all_script) = agent.after_all_script() {
let outputs_value = if let Some(outputs) = outputs.as_ref() {
Value::Array(outputs.clone())
} else {
Value::Null
};
let lua_engine = runtime.new_lua_engine()?;
let lua_scope = lua_engine.create_table()?;
let inputs = Value::Array(inputs);
lua_scope.set("inputs", lua_engine.serde_to_lua_value(inputs)?)?;
lua_scope.set("outputs", lua_engine.serde_to_lua_value(outputs_value)?)?;
lua_scope.set("before_all", lua_engine.serde_to_lua_value(before_all)?)?;
lua_scope.set("CTX", literals.to_lua(&lua_engine)?)?;
lua_scope.set("options", agent.options_as_ref())?;
let lua_value = lua_engine.eval(after_all_script, Some(lua_scope), Some(&[agent.file_dir()?.to_str()]))?;
Some(serde_json::to_value(lua_value)?)
} else {
None
};
Ok(RunCommandResponse { after_all, outputs })
}
/// Runs the agent on a single input and publishes progress messages to the hub.
///
/// The label used in the messages comes from `get_input_label`. When that finds nothing,
/// the label is `input index: {input_idx}`. If the response converts to a string, that
/// string is published as the agent output.
async fn run_command_agent_input(
	input_idx: usize,
	runtime: &Runtime,
	agent: &Agent,
	before_all: Value,
	input: impl Serialize,
	literals: &Literals,
	run_base_options: &RunBaseOptions,
) -> Result<Option<RunAgentInputResponse>> {
	let hub = get_hub();

	let input_value = serde_json::to_value(input)?;
	let label = match get_input_label(&input_value) {
		Some(found) => found,
		None => format!("input index: {input_idx}"),
	};

	hub.publish(format!("\n==== Running input: {label}")).await;

	let response = run_agent_input(runtime, agent, before_all, &label, input_value, literals, run_base_options).await?;

	let response_text = response.as_ref().and_then(|r| r.as_str());
	if let Some(text) = response_text {
		hub.publish(format!("-> Agent Output:\n{text}")).await;
	}

	hub.publish(format!("-- DONE (input: {label})")).await;

	Ok(response)
}
/// Test-only entry point for `run_command_agent_input`.
///
/// Builds the `Literals` from the runtime's dir context and the agent, then runs a
/// single input.
#[cfg(test)]
pub async fn run_command_agent_input_for_test(
	input_idx: usize,
	runtime: &Runtime,
	agent: &Agent,
	before_all: Value,
	input: impl Serialize,
	run_base_options: &RunBaseOptions,
) -> Result<Option<RunAgentInputResponse>> {
	let ctx_literals = Literals::from_dir_context_and_agent_path(runtime.dir_context(), agent)?;

	let res = run_command_agent_input(input_idx, runtime, agent, before_all, input, &ctx_literals, run_base_options).await;
	res
}
/// Returns a human-readable label for an input.
///
/// The label is the first value found under `path`, `name`, `label`, or `_label`, in that
/// order, that reads as a `String`. Returns `None` when none of these keys yields one.
fn get_input_label(input: &Value) -> Option<String> {
	const LABEL_KEYS: &[&str] = &["path", "name", "label", "_label"];

	LABEL_KEYS.iter().find_map(|&key| input.x_get::<String>(key).ok())
}
/// Builds the GenAI settings suffix for the "Running agent command" banner.
///
/// Returns `" (temperature: X)"` when the agent options set a temperature, or an empty
/// string when there is nothing to report.
fn get_genai_info(agent: &Agent) -> String {
	let options = agent.options();

	let infos: Vec<String> = options
		.temperature()
		.map(|temp| format!("temperature: {temp}"))
		.into_iter()
		.collect();

	if infos.is_empty() {
		String::new()
	} else {
		format!(" ({})", infos.join(", "))
	}
}
// Test modules for this file. Their sources live in `../_tests/` and are wired in
// with `#[path]`; they are compiled only for test builds.
#[cfg(test)]
#[path = "../_tests/tests_run_agent_llm.rs"]
mod tests_run_agent_llm;
#[cfg(test)]
#[path = "../_tests/tests_run_agent_script.rs"]
mod tests_run_agent_script;