use crate::Result;
use crate::agent::Agent;
use crate::hub::get_hub;
use crate::model::base::DbBmc;
use crate::model::{
EndState, Id, LogBmc, LogForCreate, LogKind, ModelManager, RunBmc, RunForCreate, RunForUpdate, Stage, TaskBmc,
TaskForCreate, TaskForUpdate, TypedContent,
};
use crate::run::ModelPricing;
use crate::runtime::Runtime;
use derive_more::From;
use genai::ModelIden;
use serde_json::Value;
use uuid::Uuid;
/// Model-layer access bound to a borrowed [`Runtime`].
///
/// Thin wrapper whose methods forward to the `*Bmc` persistence helpers through the
/// runtime's `ModelManager`. The `From` derive (derive_more) lets a `&Runtime` be
/// converted directly into an `RtModel`.
#[derive(From, Debug)]
pub struct RtModel<'a> {
	/// Runtime whose model manager backs every persistence call.
	runtime: &'a Runtime,
}
impl<'a> RtModel<'a> {
pub(super) fn new(runtime: &'a Runtime) -> Self {
Self { runtime }
}
fn mm(&self) -> &ModelManager {
self.runtime.mm()
}
}
impl<'a> RtModel<'a> {
pub async fn create_run(&self, parent_uid: Option<Uuid>, agent: &Agent) -> Result<Id> {
let hub = get_hub();
let agent_path = match self.runtime.dir_context().get_display_path(agent.file_path()) {
Ok(path) => path.to_string(),
Err(_) => agent.file_path().to_string(),
};
let agent_name = agent.name();
let parent_id = if let Some(uid) = parent_uid {
Some(RunBmc::get_id_for_uid(self.mm(), uid)?)
} else {
None
};
let run_id = RunBmc::create(
self.mm(),
RunForCreate {
parent_id,
agent_name: Some(agent_name.to_string()),
agent_path: Some(agent_path.to_string()),
has_task_stages: Some(agent.has_task_stages()),
has_prompt_parts: Some(agent.has_prompt_parts()),
},
)?;
hub.publish(format!(
"\n======= RUNNING: {agent_name}\n Agent path: {agent_path}",
))
.await;
Ok(run_id)
}
pub async fn update_run_model_and_concurrency(
&self,
run_id: Id,
model_name: &str,
concurrency: usize,
) -> Result<()> {
let run_u = RunForUpdate {
model: Some(model_name.to_string()),
concurrency: Some(concurrency as i32),
..Default::default()
};
RunBmc::update(self.mm(), run_id, run_u)?;
Ok(())
}
pub async fn update_run_flow_redo_count(&self, run_id: Id, flow_redo_count: i32) -> Result<()> {
let run_u = RunForUpdate {
flow_redo_count: Some(flow_redo_count),
..Default::default()
};
RunBmc::update(self.mm(), run_id, run_u)?;
Ok(())
}
pub fn set_run_end_error(&self, run_id: Id, stage: Option<Stage>, err: &crate::Error) -> Result<()> {
RunBmc::set_end_error(self.mm(), run_id, stage, err)?;
Ok(())
}
pub fn set_run_end_state_to_skip(&self, run_id: Id) -> Result<()> {
RunBmc::update(
self.mm(),
run_id,
RunForUpdate {
end_state: Some(EndState::Skip),
..Default::default()
},
)?;
Ok(())
}
pub async fn rec_skip_run(&self, run_id: Id, stage: Stage, reason: Option<String>) -> Result<()> {
let mm = self.mm();
let reason_txt = reason.as_ref().map(|r| format!(" (Reason: {r})")).unwrap_or_default();
RunBmc::update(
mm,
run_id,
RunForUpdate {
end_skip_reason: reason.clone(),
..Default::default()
},
)?;
let log_c = LogForCreate {
run_id,
task_id: None,
step: None,
stage: Some(stage),
message: reason,
kind: Some(LogKind::AgentSkip),
};
LogBmc::create(mm, log_c)?;
get_hub()
.publish(format!("Aipack Skip input at {stage} stage: {reason_txt}"))
.await;
Ok(())
}
}
impl<'a> RtModel<'a> {
	/// Creates task number `idx` of run `run_id`.
	///
	/// The task's input content is `input` converted with `TypedContent::from_value`.
	/// No label is set. Returns the id of the new task.
	pub async fn create_task(&self, run_id: Id, idx: usize, input: &Value) -> Result<Id> {
		let new_task = TaskForCreate {
			run_id,
			idx: idx as i64,
			label: None,
			input_content: Some(TypedContent::from_value(input)),
		};
		let task_id = TaskBmc::create(self.mm(), new_task)?;
		Ok(task_id)
	}

	/// Inserts all `items` in one batch and returns their ids.
	///
	/// Every item is expected to belong to `run_id`. This is checked in debug builds only.
	pub async fn create_tasks_batch(&self, run_id: Id, items: Vec<TaskForCreate>) -> Result<Vec<Id>> {
		debug_assert!(!items.iter().any(|task| task.run_id != run_id));
		let task_ids = TaskBmc::create_batch(self.mm(), items)?;
		Ok(task_ids)
	}

	/// Stores `model_name_ov` as the model override of task `task_id`.
	///
	/// `_run_id` is not used.
	pub async fn update_task_model_ov(&self, _run_id: Id, task_id: Id, model_name_ov: &str) -> Result<()> {
		let update = TaskForUpdate {
			model_ov: Some(String::from(model_name_ov)),
			..Default::default()
		};
		TaskBmc::update(self.mm(), task_id, update)?;
		Ok(())
	}

	/// Copies the pricing model name and its input, cached-input and output prices
	/// from `pricing` onto task `task_id`.
	///
	/// `_run_id` is not used.
	pub async fn update_task_model_pricing(&self, _run_id: Id, task_id: Id, pricing: &ModelPricing) -> Result<()> {
		let pricing_update = TaskForUpdate {
			pricing_model: Some(pricing.name.to_string()),
			pricing_input: Some(pricing.input_normal),
			pricing_input_cached: pricing.input_cached,
			pricing_output: Some(pricing.output_normal),
			..Default::default()
		};
		TaskBmc::update(self.mm(), task_id, pricing_update)?;
		Ok(())
	}

	/// Stores the usage fields derived from `usage` on task `task_id`, together with
	/// the upstream model name taken from `model_upstream`.
	///
	/// `_run_id` is not used.
	pub async fn update_task_usage(
		&self,
		_run_id: Id,
		task_id: Id,
		usage: &genai::chat::Usage,
		model_upstream: &ModelIden,
	) -> Result<()> {
		let mut usage_update = TaskForUpdate::from_usage(usage);
		usage_update.model_upstream = Some(model_upstream.model_name.to_string());
		TaskBmc::update(self.mm(), task_id, usage_update)?;
		Ok(())
	}

	/// Stores the cost fields on task `task_id`, then refreshes the run's `total_cost`.
	///
	/// The new total is the sum of `cost` over all tasks listed for `run_id`. Tasks
	/// without a cost are ignored.
	pub async fn update_task_cost(
		&self,
		run_id: Id,
		task_id: Id,
		cost: f64,
		cost_cache_write: Option<f64>,
		cost_cache_saving: Option<f64>,
	) -> Result<()> {
		TaskBmc::update(
			self.mm(),
			task_id,
			TaskForUpdate {
				cost: Some(cost),
				cost_cache_write,
				cost_cache_saving,
				..Default::default()
			},
		)?;

		// Recomputed from all tasks each time rather than incremented.
		let total_cost = TaskBmc::list_for_run(self.mm(), run_id)?
			.iter()
			.filter_map(|task| task.cost)
			.sum::<f64>();

		RunBmc::update(
			self.mm(),
			run_id,
			RunForUpdate {
				total_cost: Some(total_cost),
				..Default::default()
			},
		)?;
		Ok(())
	}

	/// Persists `output` (converted with `TypedContent::from_value`) as the output
	/// of task `task_id`.
	///
	/// Nothing is written when the converted value has neither `content` nor `display`.
	pub async fn update_task_output(&self, task_id: Id, output: &Value) -> Result<()> {
		let output_content = TypedContent::from_value(output);
		let is_empty = output_content.content.is_none() && output_content.display.is_none();
		if !is_empty {
			TaskBmc::update_output(self.mm(), task_id, output_content)?;
		}
		Ok(())
	}

	/// Records `err` (and the optional `stage` it occurred in) on task `task_id`
	/// via `TaskBmc::set_end_error_no_end`.
	///
	/// Going by that helper's name, the task is presumably not marked as ended here.
	/// Confirm before relying on that.
	pub fn set_task_end_error(&self, _run_id: Id, task_id: Id, stage: Option<Stage>, err: &crate::Error) -> Result<()> {
		TaskBmc::set_end_error_no_end(self.mm(), task_id, stage, err)?;
		Ok(())
	}

	/// Sets the `end_state` of task `task_id` to `EndState::Skip`.
	///
	/// `_run_id` is not used.
	pub fn set_task_end_state_to_skip(&self, _run_id: Id, task_id: Id) -> Result<()> {
		let skip_update = TaskForUpdate {
			end_state: Some(EndState::Skip),
			..Default::default()
		};
		TaskBmc::update(self.mm(), task_id, skip_update)?;
		Ok(())
	}

	/// Records that task `task_id` of run `run_id` was skipped at `stage`.
	///
	/// The method does three things in order:
	/// 1. Stores `reason` as the task's `end_skip_reason`.
	/// 2. Writes a `LogKind::AgentSkip` log entry linked to the task.
	/// 3. Publishes a notice on the hub.
	pub async fn rec_skip_task(&self, run_id: Id, task_id: Id, stage: Stage, reason: Option<String>) -> Result<()> {
		let mm = self.mm();
		let reason_txt = match &reason {
			Some(r) => format!(" (Reason: {r})"),
			None => String::new(),
		};

		TaskBmc::update(
			mm,
			task_id,
			TaskForUpdate {
				end_skip_reason: reason.clone(),
				..Default::default()
			},
		)?;

		LogBmc::create(
			mm,
			LogForCreate {
				run_id,
				task_id: Some(task_id),
				step: None,
				stage: Some(stage),
				message: reason,
				kind: Some(LogKind::AgentSkip),
			},
		)?;

		get_hub()
			.publish(format!("Aipack Skip input at {stage} stage: {reason_txt}"))
			.await;
		Ok(())
	}
}