#![warn(missing_docs)]
pub mod blueprints;
pub mod config;
pub mod data;
pub mod doctor;
pub mod enhance_log;
pub mod enhance_settings;
pub mod issues;
pub mod operator_ws;
pub mod worker;
pub use blueprints::{build_blueprints_router, build_blueprints_router_with_refs};
pub use enhance_log::build_enhance_log_router;
pub use enhance_settings::build_enhance_settings_router;
pub use issues::{build_issues_router, GetIssueResponse, PostIssueRequest, PostIssueResponse};
pub use operator_ws::{
operators_create, operators_delete, operators_info, operators_ws_connect, ClientMsg,
OperatorSessionEntry, ServerMsg, WSOperatorSession,
};
pub use worker::{worker_prompt, worker_result, PromptQuery, WorkerResultReq};
use axum::{
extract::State,
http::{header::AUTHORIZATION, HeaderMap, StatusCode},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use mlua_swarm::application::{BlueprintRef, TaskApplication};
use mlua_swarm::blueprint::store::BlueprintStore;
use mlua_swarm::service::TaskLaunchService;
use mlua_swarm::{
CapToken, Compiler, Engine, LayerRegistry, LuaInProcessSpawnerFactory, MainAIMiddleware,
OperatorDelegateMiddleware, OperatorSpawnerFactory, Role, RustFnInProcessSpawnerFactory,
SeniorEscalationMiddleware, SpawnerRegistry, SubprocessProcessSpawnerFactory,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
/// In-memory registry of sessions attached through `POST /v1/sessions`.
///
/// Entries are inserted by `sessions_attach` and removed by `take_session_token`
/// (used by `DELETE /v1/sessions`).
/// NOTE(review): within this file nothing evicts an entry whose engine TTL has lapsed —
/// confirm whether another module prunes it.
#[derive(Default)]
pub struct SessionStore {
    /// Session id (the attached token's `nonce`) → capability token issued by the engine.
    pub map: HashMap<String, CapToken>,
}
/// Shared state handed to every route registered by `build_router_with_ws_factory_and_output`.
///
/// `Clone` is required by axum's `State` extractor; the mutable maps live behind
/// `Arc<tokio::sync::Mutex<_>>` so every clone shares the same data.
#[derive(Clone)]
pub struct AppState {
    /// Engine used for session attach/detach and for looking up registered operator ids.
    pub engine: Engine,
    /// Sessions attached through `POST /v1/sessions`, keyed by session id.
    pub sessions: Arc<Mutex<SessionStore>>,
    /// Task application that resolves blueprints and runs flows for `POST /v1/tasks`.
    pub task_app: Arc<TaskApplication>,
    /// Operator spawner factory supplied by the router builder's caller, if any.
    /// NOTE(review): presumably consumed by the `operator_ws` handlers — confirm.
    pub ws_operator_factory: Option<Arc<OperatorSpawnerFactory>>,
    /// Output store; the router builder falls back to an `InMemoryOutputStore` when none is
    /// given. Presumably read/written by the `/v1/data/*` handlers — TODO confirm.
    pub data_store: Arc<dyn mlua_swarm::store::output::OutputStore>,
    /// Operator session entries keyed by a string id.
    /// NOTE(review): presumably the `:sid` of `/v1/operators/:sid` — verify in `operator_ws`.
    pub operator_sessions:
        Arc<Mutex<HashMap<String, Arc<crate::operator_ws::login::OperatorSessionEntry>>>>,
    /// String → string map; presumably role name → operator session id, going by the field
    /// name — TODO confirm against `operator_ws`.
    pub roles_to_sid: Arc<Mutex<HashMap<String, String>>>,
}
/// Builds the HTTP router using [`default_registry`] and no blueprint store.
///
/// Without a store, `/v1/tasks` can only run inline blueprints.
pub fn build_router(engine: Engine) -> Router {
    let registry = default_registry();
    build_router_with(engine, registry, None)
}
/// Returns a [`LayerRegistry`] pre-populated with the built-in middleware hints.
///
/// Registered hint names: `main_ai`, `senior_escalation` and `operator_delegate`. Each hint
/// ignores the engine argument and builds a fresh middleware instance.
pub fn default_layer_registry() -> LayerRegistry {
    let layers = LayerRegistry::new();
    let layers = layers.with_hint("main_ai", |_| Arc::new(MainAIMiddleware::new()));
    let layers = layers.with_hint("senior_escalation", |_| {
        Arc::new(SeniorEscalationMiddleware::new())
    });
    layers.with_hint("operator_delegate", |_| {
        Arc::new(OperatorDelegateMiddleware::new())
    })
}
pub fn build_router_with(
engine: Engine,
registry: SpawnerRegistry,
store: Option<Arc<dyn BlueprintStore>>,
) -> Router {
build_router_with_ws_factory(engine, registry, store, None)
}
pub fn build_router_with_ws_factory(
engine: Engine,
registry: SpawnerRegistry,
store: Option<Arc<dyn BlueprintStore>>,
ws_operator_factory: Option<Arc<OperatorSpawnerFactory>>,
) -> Router {
build_router_with_ws_factory_and_output(engine, registry, store, ws_operator_factory, None)
}
pub fn build_router_with_ws_factory_and_output(
engine: Engine,
registry: SpawnerRegistry,
store: Option<Arc<dyn BlueprintStore>>,
ws_operator_factory: Option<Arc<OperatorSpawnerFactory>>,
output_store: Option<Arc<dyn mlua_swarm::store::output::OutputStore>>,
) -> Router {
let compiler = Compiler::new(registry);
let launch = Arc::new(TaskLaunchService::new(engine.clone(), compiler));
let task_app = Arc::new(match store {
Some(s) => TaskApplication::new(launch, s),
None => TaskApplication::new_inline_only(launch),
});
let data_store: Arc<dyn mlua_swarm::store::output::OutputStore> = match output_store {
Some(s) => s,
None => Arc::new(mlua_swarm::store::output::InMemoryOutputStore::new()),
};
let state = AppState {
engine,
sessions: Arc::new(Mutex::new(SessionStore::default())),
task_app,
ws_operator_factory,
data_store,
operator_sessions: Arc::new(Mutex::new(HashMap::new())),
roles_to_sid: Arc::new(Mutex::new(HashMap::new())),
};
Router::new()
.route("/v1/healthz", get(healthz))
.route(
"/v1/sessions",
post(sessions_attach).delete(sessions_detach),
)
.route("/v1/tasks", post(tasks_start))
.route("/v1/operators", post(operators_create))
.route("/v1/operators/:sid/ws", get(operators_ws_connect))
.route(
"/v1/operators/:sid",
get(operators_info).delete(operators_delete),
)
.route("/v1/worker/prompt", get(worker::worker_prompt))
.route("/v1/worker/result", post(worker::worker_result))
.route("/v1/worker/submit", post(worker::worker_submit))
.route("/v1/data/emit", post(data::data_emit))
.route(
"/v1/data/:key",
get(data::data_get).post(data::data_emit_named),
)
.with_state(state)
}
/// Spawner registry used by [`build_router`].
///
/// Registers, in order: the subprocess spawner, the Rust-fn in-process spawner (extended via
/// `extend_with_baseline`), and a fresh [`OperatorSpawnerFactory`].
pub fn default_registry() -> SpawnerRegistry {
    use mlua_swarm::worker::baseline::extend_with_baseline;

    let baseline_rustfn = extend_with_baseline(RustFnInProcessSpawnerFactory::new());

    let mut registry = SpawnerRegistry::new();
    registry
        .register::<SubprocessProcessSpawnerFactory>(Arc::new(SubprocessProcessSpawnerFactory));
    registry.register::<RustFnInProcessSpawnerFactory>(Arc::new(baseline_rustfn));
    registry.register::<OperatorSpawnerFactory>(Arc::new(OperatorSpawnerFactory::new()));
    registry
}
/// Like [`default_registry`], plus the Lua and agent-block in-process spawners.
///
/// Registers, in order: subprocess; Rust-fn in-process (extended via `extend_with_baseline`);
/// Lua in-process (extended via `enhance::blueprint::extend_factory`); agent-block in-process;
/// and a fresh [`OperatorSpawnerFactory`].
pub fn default_registry_with_enhance_flow() -> SpawnerRegistry {
    use mlua_swarm::enhance::blueprint::extend_factory;
    use mlua_swarm::worker::agent_block::AgentBlockInProcessSpawnerFactory;
    use mlua_swarm::worker::baseline::extend_with_baseline;

    let enhanced_lua = extend_factory(LuaInProcessSpawnerFactory::new());
    let agent_block = AgentBlockInProcessSpawnerFactory::new();
    let baseline_rustfn = extend_with_baseline(RustFnInProcessSpawnerFactory::new());

    let mut registry = SpawnerRegistry::new();
    registry
        .register::<SubprocessProcessSpawnerFactory>(Arc::new(SubprocessProcessSpawnerFactory));
    registry.register::<RustFnInProcessSpawnerFactory>(Arc::new(baseline_rustfn));
    registry.register::<LuaInProcessSpawnerFactory>(Arc::new(enhanced_lua));
    registry.register::<AgentBlockInProcessSpawnerFactory>(Arc::new(agent_block));
    registry.register::<OperatorSpawnerFactory>(Arc::new(OperatorSpawnerFactory::new()));
    registry
}
/// `GET /v1/healthz` — liveness probe that always answers with the plain-text body `ok`.
async fn healthz() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
/// JSON body of `POST /v1/sessions`.
#[derive(Deserialize)]
struct AttachReq {
    /// Agent identifier passed through to the engine's `attach`.
    agent_id: String,
    /// Role name; parsed case-insensitively by `parse_role`.
    role: String,
    /// TTL in seconds, passed to the engine's `attach` as a `Duration`.
    ttl_secs: u64,
}
/// JSON response of `POST /v1/sessions`.
#[derive(Serialize)]
struct AttachResp {
    /// Session id to send back as `Authorization: Bearer <session_id>`.
    session_id: String,
    /// The role string exactly as the client sent it (not normalized).
    role: String,
}
async fn sessions_attach(
State(state): State<AppState>,
Json(req): Json<AttachReq>,
) -> Result<Json<AttachResp>, ApiError> {
let role = parse_role(&req.role)?;
let token = state
.engine
.attach(req.agent_id, role, Duration::from_secs(req.ttl_secs))
.await
.map_err(ApiError::engine)?;
let sid = token.nonce.clone();
state.sessions.lock().await.map.insert(sid.clone(), token);
Ok(Json(AttachResp {
session_id: sid,
role: req.role,
}))
}
/// `DELETE /v1/sessions` — detaches the session named by `Authorization: Bearer <sid>`.
///
/// The session is removed from [`AppState::sessions`] *before* the engine detach runs. An
/// engine failure (500) therefore still leaves the session unregistered here.
///
/// # Errors
/// 400 for a missing or malformed header, 404 for an unknown session, 500 on engine failure.
async fn sessions_detach(
    State(state): State<AppState>,
    headers: HeaderMap,
) -> Result<StatusCode, ApiError> {
    let session_id = extract_bearer(&headers)?;
    let token = take_session_token(&state, &session_id).await?;
    if let Err(e) = state.engine.detach(&token).await {
        return Err(ApiError::engine(e));
    }
    Ok(StatusCode::NO_CONTENT)
}
/// JSON body of `POST /v1/tasks`.
#[derive(Deserialize)]
struct FlowTasksReq {
    /// Blueprint to run: inline (`Inline { value }`) or by reference (`Id { id, version }`).
    blueprint: BlueprintRef,
    /// Initial flow context handed to the task application.
    init_ctx: Value,
    /// Run TTL in seconds; when present it wins over blueprint metadata and the server default.
    #[serde(default)]
    ttl_secs: Option<u64>,
    /// Operator settings; every field inside is optional.
    #[serde(default)]
    operator: Option<OperatorReq>,
    /// Operator session id registered with the engine. When set, it must exist and it
    /// replaces `operator.operator_backend_id`.
    #[serde(default)]
    operator_sid: Option<String>,
}
/// Optional operator settings inside a `POST /v1/tasks` body.
#[derive(Deserialize, Default)]
struct OperatorReq {
    /// Default operator kind: `main_ai`, `composite` or `automate` (exact match).
    #[serde(default)]
    kind: Option<String>,
    /// Operator id for the run; `run_flow_form` defaults it to `"http-run"`.
    #[serde(default)]
    id: Option<String>,
    /// Forwarded to the task application as `hook_id`.
    #[serde(default)]
    spawn_hook_id: Option<String>,
    /// Forwarded to the task application as `bridge_id`.
    #[serde(default)]
    senior_bridge_id: Option<String>,
    /// Forwarded as `operator_backend_id`; overridden by the top-level `operator_sid`.
    #[serde(default)]
    operator_backend_id: Option<String>,
    /// Per-agent kind overrides: agent → kind string (same accepted values as `kind`).
    #[serde(default)]
    per_agent_kinds: Option<HashMap<String, String>>,
}
/// Parses an operator kind name: `main_ai`, `composite` or `automate`.
///
/// Matching is exact (case-sensitive).
///
/// # Errors
/// Returns a 400 [`ApiError`] that names the rejected value.
fn parse_operator_kind_str(s: &str) -> Result<mlua_swarm::OperatorKind, ApiError> {
    use mlua_swarm::OperatorKind;

    let kind = match s {
        "main_ai" => OperatorKind::MainAi,
        "composite" => OperatorKind::Composite,
        "automate" => OperatorKind::Automate,
        unknown => {
            return Err(ApiError::bad_request(format!(
                "operator kind: unknown value '{unknown}' (expected main_ai|automate|composite)"
            )));
        }
    };
    Ok(kind)
}
/// JSON response of `POST /v1/tasks`.
#[derive(Serialize)]
struct FlowTasksResp {
    /// Final flow context returned by the task application.
    final_ctx: Value,
    /// Version the run was bound to, rendered with `{:?}`; `None` when none was reported.
    bound_version: Option<String>,
    /// TTL actually applied to the run, in seconds.
    effective_ttl_secs: u64,
    /// Which layer of the TTL cascade supplied `effective_ttl_secs`.
    ttl_source: TtlSource,
}
/// Origin of the TTL applied to a `/v1/tasks` run.
///
/// Serialized in snake_case: `request_body`, `bp_metadata`, `server_default`.
#[derive(Serialize, Clone, Copy, Debug, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum TtlSource {
    /// `ttl_secs` was present in the request body.
    RequestBody,
    /// Taken from the blueprint's `metadata.default_run_ttl_secs`.
    BpMetadata,
    /// Neither was set; `default_run_ttl()` was used.
    ServerDefault,
}
async fn tasks_start(
State(state): State<AppState>,
Json(req): Json<FlowTasksReq>,
) -> Result<Json<FlowTasksResp>, ApiError> {
let resp = run_flow_form(&state, req).await?;
Ok(Json(resp))
}
/// Runs the blueprint flow described by a `POST /v1/tasks` body and awaits its result.
///
/// Resolution order:
/// 1. `operator_sid`, when present, must be one of the engine's registered operator ids. It
///    then replaces `operator.operator_backend_id`.
/// 2. `operator.kind` and every `operator.per_agent_kinds` value are parsed with
///    `parse_operator_kind_str`. `operator.id` defaults to `"http-run"`.
/// 3. TTL cascade: request body → blueprint `metadata.default_run_ttl_secs` →
///    `default_run_ttl()`. The blueprint is resolved up front only when the body omits
///    `ttl_secs`.
/// 4. The run executes through [`AppState::task_app`] as [`Role::Operator`].
///
/// # Errors
/// Every failure is reported as `400 Bad Request`: an unknown `operator_sid`, an unknown
/// operator kind, a blueprint resolution error, or a failed run.
async fn run_flow_form(state: &AppState, req: FlowTasksReq) -> Result<FlowTasksResp, ApiError> {
    use mlua_swarm::application::TaskApplicationInput;
    use mlua_swarm::Application;
    use mlua_swarm::OperatorKind;

    let mut op_req = req.operator.unwrap_or_default();
    if let Some(sid) = req.operator_sid {
        let known_ids = state.engine.list_operator_ids().await;
        if !known_ids.iter().any(|id| id == &sid) {
            return Err(ApiError::bad_request(format!(
                "operator_sid: no such registered operator session '{sid}'"
            )));
        }
        // The validated session id wins over any backend id given in `operator`.
        op_req.operator_backend_id = Some(sid);
    }

    let operator_kind = op_req
        .kind
        .as_deref()
        .map(parse_operator_kind_str)
        .transpose()?;
    let operator_id = op_req.id.unwrap_or_else(|| "http-run".to_string());
    // Fails on the first unparsable per-agent kind, like the default `kind` above.
    let operator_kind_overrides = op_req
        .per_agent_kinds
        .unwrap_or_default()
        .into_iter()
        .map(|(agent, kind)| parse_operator_kind_str(&kind).map(|parsed| (agent, parsed)))
        .collect::<Result<HashMap<String, OperatorKind>, ApiError>>()?;

    // `req.blueprint` already has the type `TaskApplicationInput` expects; no re-mapping needed.
    let blueprint = req.blueprint;

    let (ttl_secs, ttl_source) = match req.ttl_secs {
        Some(secs) => (secs, TtlSource::RequestBody),
        None => {
            let (resolved_bp, _version) = state
                .task_app
                .resolve(&blueprint)
                .await
                .map_err(|e| ApiError::bad_request(format!("bp resolve: {e}")))?;
            match resolved_bp.metadata.default_run_ttl_secs {
                Some(secs) => (secs, TtlSource::BpMetadata),
                None => (default_run_ttl(), TtlSource::ServerDefault),
            }
        }
    };

    let out = state
        .task_app
        .handle(TaskApplicationInput {
            blueprint,
            operator_id,
            role: Role::Operator,
            ttl: Duration::from_secs(ttl_secs),
            init_ctx: req.init_ctx,
            operator_kind,
            bridge_id: op_req.senior_bridge_id,
            hook_id: op_req.spawn_hook_id,
            operator_backend_id: op_req.operator_backend_id,
            operator_kind_overrides,
        })
        .await
        .map_err(|e| ApiError::bad_request(format!("run: {e}")))?;

    Ok(FlowTasksResp {
        final_ctx: out.final_ctx,
        // NOTE(review): `{:?}` is used because the version type's `Display` (if any) is not
        // visible here. If the version is a `String`, Debug wraps it in quotes — confirm.
        bound_version: out.bound_version.map(|v| format!("{:?}", v)),
        effective_ttl_secs: ttl_secs,
        ttl_source,
    })
}
/// Removes the session `sid` from [`AppState::sessions`] and returns its token.
///
/// # Errors
/// 404 (`session: <sid>`) when no session is registered under `sid`.
async fn take_session_token(state: &AppState, sid: &str) -> Result<CapToken, ApiError> {
    let removed = state.sessions.lock().await.map.remove(sid);
    match removed {
        Some(token) => Ok(token),
        None => Err(ApiError::not_found(format!("session: {sid}"))),
    }
}
/// Extracts the session id from an `Authorization: Bearer <sid>` header.
///
/// The scheme name is matched case-insensitively, because HTTP authentication scheme names
/// are case-insensitive (RFC 9110 §11.1). Whitespace around the sid is trimmed.
///
/// # Errors
/// Returns 400 in four cases:
/// * the header is missing;
/// * the header is not valid visible ASCII;
/// * it uses a scheme other than `Bearer`;
/// * the sid is empty.
fn extract_bearer(headers: &HeaderMap) -> Result<String, ApiError> {
    let value = headers
        .get(AUTHORIZATION)
        .ok_or_else(|| ApiError::bad_request("missing Authorization header".into()))?
        .to_str()
        .map_err(|_| ApiError::bad_request("invalid Authorization header encoding".into()))?;
    // Split "<scheme> <credentials>" on the first space and compare the scheme
    // case-insensitively, so `bearer <sid>` and `BEARER <sid>` are accepted too.
    let sid = match value.split_once(' ') {
        Some((scheme, credentials)) if scheme.eq_ignore_ascii_case("Bearer") => {
            credentials.trim()
        }
        _ => {
            return Err(ApiError::bad_request(
                "Authorization must be 'Bearer <sid>'".into(),
            ));
        }
    };
    if sid.is_empty() {
        return Err(ApiError::bad_request("Bearer sid is empty".into()));
    }
    Ok(sid.to_owned())
}
/// Parses a session role name case-insensitively.
///
/// Accepted names: `operator`, `worker`, `observer` and `senior`.
///
/// # Errors
/// 400 `unknown role: <name>`, reporting the lower-cased name, for anything else.
fn parse_role(s: &str) -> Result<Role, ApiError> {
    let normalized = s.to_ascii_lowercase();
    let role = match normalized.as_str() {
        "operator" => Role::Operator,
        "worker" => Role::Worker,
        "observer" => Role::Observer,
        "senior" => Role::Senior,
        unknown => return Err(ApiError::bad_request(format!("unknown role: {unknown}"))),
    };
    Ok(role)
}
/// Error type returned by the HTTP handlers.
///
/// Rendered by its `IntoResponse` impl as `status` plus the JSON body `{"error": message}`.
#[derive(Debug)]
pub struct ApiError {
    /// HTTP status code sent to the client.
    status: StatusCode,
    /// Human-readable message placed under the `error` key of the JSON body.
    message: String,
}
impl ApiError {
pub fn engine(e: impl std::fmt::Display) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
message: format!("engine: {e}"),
}
}
pub fn not_found(m: String) -> Self {
Self {
status: StatusCode::NOT_FOUND,
message: m,
}
}
pub fn bad_request(m: String) -> Self {
Self {
status: StatusCode::BAD_REQUEST,
message: m,
}
}
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
(self.status, Json(json!({"error": self.message}))).into_response()
}
}
/// Server-side fallback run TTL, in seconds (30 minutes).
///
/// Used when neither the request body nor the blueprint metadata provides one.
fn default_run_ttl() -> u64 {
    const THIRTY_MINUTES_SECS: u64 = 30 * 60;
    THIRTY_MINUTES_SECS
}
/// Test helper: the blueprint-metadata TTL when present, otherwise `default_run_ttl()`.
#[cfg(test)]
fn resolve_ttl_from_metadata(metadata_ttl: Option<u64>) -> u64 {
    match metadata_ttl {
        Some(ttl) => ttl,
        None => default_run_ttl(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mirrors the `/v1/tasks` TTL cascade: request body first, then metadata, then the
    /// server default (via `resolve_ttl_from_metadata`).
    fn cascade(req_ttl: Option<u64>, metadata_ttl: Option<u64>) -> u64 {
        if let Some(body_ttl) = req_ttl {
            body_ttl
        } else {
            resolve_ttl_from_metadata(metadata_ttl)
        }
    }

    #[test]
    fn ttl_cascade_request_body_wins_over_metadata() {
        let effective = cascade(Some(100), Some(3600));
        assert_eq!(
            effective, 100,
            "request body ttl_secs=100 must win over metadata=3600 (cascade priority (1) > (2))"
        );
    }

    #[test]
    fn ttl_cascade_metadata_used_when_body_missing() {
        let effective = cascade(None, Some(3600));
        assert_eq!(
            effective, 3600,
            "body None + metadata=3600 must resolve to 3600 (cascade (2))"
        );
    }

    #[test]
    fn ttl_cascade_server_default_when_both_missing() {
        let effective = cascade(None, None);
        assert_eq!(
            effective,
            default_run_ttl(),
            "body None + metadata None must fall back to default_run_ttl() = 1800s"
        );
        assert_eq!(effective, 1800, "default_run_ttl() literal = 1800s");
    }

    #[test]
    fn resolve_ttl_from_metadata_none_returns_server_default() {
        assert_eq!(resolve_ttl_from_metadata(None), 1800);
    }

    #[test]
    fn resolve_ttl_from_metadata_some_returns_value() {
        for ttl in [7200, 60] {
            assert_eq!(resolve_ttl_from_metadata(Some(ttl)), ttl);
        }
    }
}