use super::super::diagnostics::{collect_runtime_diagnostics, diagnostics_system_note};
use super::super::intent_registry::IntentRegistry;
use super::super::routing::{persist_model_selection_audit, select_routed_model_with_audit};
use super::tool_prune::prune_tool_definitions;
use super::types::{InferenceInput, PreparedContextSnapshot, PreparedInference};
use tracing::Instrument;
/// Assembles everything needed to issue one LLM inference call for the current turn.
///
/// Steps, in order:
/// 1. Classify the user message's complexity, select and audit a routed model, and record
///    the selected model on the session (best-effort).
/// 2. Embed the user message (best-effort) and retrieve memories for the prompt.
/// 3. Load up to 50 stored messages and rebuild the conversation history. Messages on the
///    current topic are kept verbatim; each contiguous off-topic run is collapsed into one
///    `system` summary.
/// 4. Build the system prompt (identity/OS/firmware text, runtime metadata, behavioral
///    contract, tool instructions) and wrap it in an HMAC boundary.
/// 5. Fit the system prompt, compacted memories and history into the configured context
///    budget. Then append checkpoint, hippocampus, diagnostic and caller-supplied system
///    notes, and finally the user message.
/// 6. Apply prompt compression (if enabled) and tier adaptation, classify intents, and
///    snapshot context-usage metrics.
///
/// # Errors
/// Returns `Err` if the conversation history cannot be loaded, or if the HMAC boundary
/// fails verification immediately after injection. Failures to update the session model,
/// compute the embedding, load the checkpoint, or build the hippocampus summary are
/// logged and skipped.
pub(crate) async fn prepare_inference(
    input: &InferenceInput<'_>,
) -> Result<PreparedInference, String> {
    let state = input.state;

    // --- Model routing --------------------------------------------------------------
    let features = roboticus_llm::extract_features(input.user_content, 0, 1);
    let complexity = roboticus_llm::classify_complexity(&features);
    let model_audit = async {
        let audit = select_routed_model_with_audit(state, input.user_content).await;
        tracing::info!(model = %audit.selected_model, "model selected for inference");
        audit
    }
    .instrument(tracing::info_span!("model_selection"))
    .await;
    let model = model_audit.selected_model.clone();
    let complexity_label = format!("{complexity:?}");
    persist_model_selection_audit(
        state,
        input.turn_id,
        input.session_id,
        input.channel_label,
        Some(&complexity_label),
        input.user_content,
        &model_audit,
    )
    .await;
    if let Err(e) = roboticus_db::sessions::update_model(&state.db, input.session_id, &model) {
        tracing::warn!(session_id = %input.session_id, model = %model, error = %e, "failed to update session model");
    }
    // Prefer the tier declared by the provider serving this model; fall back to
    // classifying the model name when no provider is found for it.
    let (tier, embedding_client) = {
        let llm = state.llm.read().await;
        let tier = llm
            .providers
            .get_by_model(&model)
            .map(|p| p.tier)
            .unwrap_or_else(|| roboticus_llm::tier::classify(&model));
        (tier, llm.embedding.clone())
    };

    // --- Memory retrieval -----------------------------------------------------------
    // The embedding is best-effort: on failure, the later stages receive `None`.
    let query_embedding = embedding_client
        .embed_single(input.user_content)
        .await
        .inspect_err(|e| {
            tracing::warn!(error = %e, "embedding generation failed, RAG retrieval will be skipped")
        })
        .ok();
    let cache_hash = roboticus_llm::SemanticCache::compute_hash("", "", input.user_content);
    let complexity_level = roboticus_agent::context::determine_level(complexity);
    // Only hand the ANN index to the retriever once it has been built.
    let ann_ref = state.ann_index.is_built().then_some(&state.ann_index);
    let retrieval_output = state.retriever.retrieve_with_metrics(
        &state.db,
        input.session_id,
        input.user_content,
        query_embedding.as_deref(),
        complexity_level,
        ann_ref,
    );
    let memories = retrieval_output.text;
    let retrieval_metrics = retrieval_output.metrics;

    // --- Conversation history -------------------------------------------------------
    let history_messages =
        roboticus_db::sessions::list_messages(&state.db, input.session_id, Some(50))
            .map_err(|e| format!("failed to load conversation history: {e}"))?;
    let previous_assistant = history_messages
        .iter()
        .rev()
        .find(|m| m.role == "assistant")
        .map(|m| m.content.clone());
    // The most recently tagged message defines the active topic.
    let current_topic = history_messages
        .iter()
        .rev()
        .find_map(|m| m.topic_tag.as_deref())
        .unwrap_or("topic-1");
    // The newest stored message is excluded from history. NOTE(review): presumably it
    // is the current user turn, which is appended separately below — confirm
    // persistence order with the caller.
    let prior_messages = history_messages
        .split_last()
        .map(|(_, rest)| rest)
        .unwrap_or_default();
    let history = build_topic_aware_history(prior_messages, current_topic);

    // --- System prompt --------------------------------------------------------------
    let model_for_api = roboticus_core::model::model_name(&model).to_string();
    let system_prompt = if input.os_text.is_empty() {
        format!(
            "You are {name}, an autonomous AI agent (id: {id}). \
             When asked who you are, always identify as {name}. \
             Never reveal the underlying model name or claim to be a generic assistant.",
            name = input.agent_name,
            id = input.agent_id,
        )
    } else {
        let mut prompt = input.os_text.clone();
        if !input.firmware_text.is_empty() {
            prompt.push_str("\n\n");
            prompt.push_str(&input.firmware_text);
        }
        if history.is_empty() {
            prompt.push_str(&format!(
                "\n\n---\n## Session Start\n\
                 This is a NEW session. You are {name}. Respond in character from your very first message. \
                 Do not introduce yourself by listing capabilities or tools. Simply embody your role.\n---",
                name = input.agent_name,
            ));
        } else {
            prompt.push_str(
                "\n\n## Continuity\n\
                 This is a continuing conversation. Do not re-introduce yourself \
                 or restate your identity. Maintain natural conversational flow.",
            );
        }
        prompt
    };
    let system_prompt = if let Some(ref wf_note) = input.delegation_workflow_note {
        format!("{system_prompt}\nWorkflow: {wf_note}")
    } else {
        system_prompt
    };
    let all_tools = super::super::decomposition::build_all_tool_definitions(state).await;
    let (tools, tool_search_stats) = prune_tool_definitions(
        state,
        all_tools,
        query_embedding.as_deref(),
        &embedding_client,
    )
    .await;
    let tool_summary: Vec<(String, String)> = tools
        .iter()
        .map(|t| (t.name.clone(), t.description.clone()))
        .collect();
    // Snapshot every config value used below under a single read lock, so the prompt,
    // context budget and compression settings all come from the same config version.
    let (
        workspace_path,
        delegation_enabled,
        context_budget_cfg,
        prompt_compression,
        compression_target_ratio,
    ) = {
        let cfg = state.config.read().await;
        (
            cfg.agent.workspace.display().to_string(),
            cfg.agent.delegation_enabled,
            cfg.context_budget.clone(),
            cfg.cache.prompt_compression,
            cfg.cache.compression_target_ratio,
        )
    };
    let system_prompt = format!(
        "{system_prompt}{}{}{}{}{}",
        roboticus_agent::prompt::runtime_metadata_block(
            env!("CARGO_PKG_VERSION"),
            &input.primary_model,
            &model,
            &workspace_path,
        ),
        roboticus_agent::prompt::behavioral_contract_block(),
        roboticus_agent::prompt::operational_introspection_block(delegation_enabled),
        roboticus_agent::prompt::subagent_orchestration_workflow_block(delegation_enabled),
        roboticus_agent::prompt::tool_use_instructions(&tool_summary),
    );
    let system_prompt =
        roboticus_agent::prompt::inject_hmac_boundary(&system_prompt, state.hmac_secret.as_ref());
    // Fail closed: never send a prompt whose boundary cannot be verified.
    if !roboticus_agent::prompt::verify_hmac_boundary(&system_prompt, state.hmac_secret.as_ref()) {
        tracing::error!("HMAC boundary verification failed immediately after injection");
        return Err("internal HMAC verification failure".into());
    }

    // --- Context assembly -----------------------------------------------------------
    // Memories are compacted to a quarter of the L0 budget. NOTE(review): the unit of
    // `compact_text`'s limit (tokens vs. characters) is not visible here — confirm.
    let compacted_memories =
        roboticus_agent::compaction::compact_text(&memories, context_budget_cfg.l0 / 4);
    let mut messages = roboticus_agent::context::build_context_with_budget(
        complexity_level,
        &system_prompt,
        &compacted_memories,
        &history,
        &context_budget_cfg,
    );
    match roboticus_db::checkpoint::load_checkpoint(&state.db, input.session_id) {
        Ok(Some(cp)) => {
            let mut checkpoint_note = format!(
                "Session checkpoint restore (turn_count={}): {}",
                cp.turn_count, cp.memory_summary
            );
            if let Some(active_tasks) = cp.active_tasks
                && !active_tasks.trim().is_empty()
            {
                checkpoint_note.push_str("\nActive tasks: ");
                checkpoint_note.push_str(&active_tasks);
            }
            if let Some(digest) = cp.conversation_digest
                && !digest.trim().is_empty()
            {
                checkpoint_note.push_str("\nConversation digest: ");
                checkpoint_note.push_str(&digest);
            }
            messages.push(system_note_message(checkpoint_note));
        }
        Ok(None) => {}
        Err(e) => tracing::warn!(error = %e, "failed to load context checkpoint"),
    }
    match roboticus_db::hippocampus::compact_summary(&state.db) {
        Ok(summary) if !summary.is_empty() => {
            messages.push(system_note_message(summary));
        }
        Err(e) => {
            tracing::warn!(error = %e, "Failed to generate hippocampus summary");
        }
        _ => {}
    }
    if input.inject_diagnostics {
        let runtime_diag = collect_runtime_diagnostics(state).await;
        messages.push(system_note_message(diagnostics_system_note(&runtime_diag)));
    }
    // Caller-supplied notes, in a fixed order: behavioral, gate, delegated execution.
    for note in [
        &input.behavioral_note,
        &input.gate_system_note,
        &input.delegated_execution_note,
    ]
    .into_iter()
    .flatten()
    {
        messages.push(system_note_message(note.clone()));
    }
    // Append the user turn unless the context already ends with it. Only a trailing
    // *user* message counts as a duplicate. An assistant or system message whose text
    // happens to equal the input (e.g. the user echoing "ok") must not swallow the turn.
    let user_turn_present = messages
        .last()
        .is_some_and(|m| m.role == "user" && m.content == input.user_content);
    if !user_turn_present {
        // Multimodal inputs: the text part comes first, followed by the attached parts.
        let parts = input.content_parts.as_ref().map(|cp| {
            let mut parts = vec![roboticus_llm::format::ContentPart::Text {
                text: input.user_content.to_string(),
            }];
            parts.extend(cp.iter().cloned());
            parts
        });
        messages.push(roboticus_llm::format::UnifiedMessage {
            role: "user".into(),
            content: input.user_content.to_string(),
            parts,
        });
    }
    if let Some(reminder) =
        roboticus_agent::prompt::build_instruction_reminder(&input.os_text, &input.firmware_text)
    {
        roboticus_agent::context::inject_instruction_reminder(&mut messages, &reminder);
    }
    if prompt_compression {
        roboticus_agent::context::compress_context(&mut messages, compression_target_ratio);
    }
    roboticus_llm::tier::adapt_for_tier(tier, &mut messages, &input.tier_adapt);

    // --- Metrics and request --------------------------------------------------------
    let mut context_snapshot =
        roboticus_agent::context::classify_context_snapshot(&messages, memories.is_empty());
    context_snapshot.token_budget =
        roboticus_agent::context::token_budget_with_config(complexity_level, &context_budget_cfg);
    let complexity_level_label = match complexity_level {
        roboticus_agent::context::ComplexityLevel::L0 => "L0",
        roboticus_agent::context::ComplexityLevel::L1 => "L1",
        roboticus_agent::context::ComplexityLevel::L2 => "L2",
        roboticus_agent::context::ComplexityLevel::L3 => "L3",
    };
    let request = roboticus_llm::format::UnifiedRequest {
        model: model_for_api,
        messages,
        max_tokens: Some(2048),
        temperature: None,
        system: None,
        quality_target: None,
        tools,
    };
    let intents = IntentRegistry::default_registry()
        .classify_semantic(
            input.user_content,
            &state.semantic_classifier,
            super::super::intent_registry::INTENT_THRESHOLD,
        )
        .await;
    // NOTE(review): `DefaultHasher`'s algorithm is unspecified and may change between
    // Rust releases. If this hash is persisted or compared across builds, switch to a
    // hash with a stable, documented algorithm.
    let system_prompt_hash = {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::hash::DefaultHasher::new();
        system_prompt.hash(&mut hasher);
        format!("{:016x}", hasher.finish())
    };
    Ok(PreparedInference {
        model,
        request,
        previous_assistant,
        query_embedding,
        cache_hash,
        system_prompt_hash,
        intents,
        context_snapshot: PreparedContextSnapshot {
            complexity_level: complexity_level_label.to_string(),
            token_budget: context_snapshot.token_budget as i64,
            system_prompt_tokens: context_snapshot.system_prompt_tokens as i64,
            memory_tokens: context_snapshot.memory_tokens as i64,
            history_tokens: context_snapshot.history_tokens as i64,
            history_depth: context_snapshot.history_depth as i64,
            memory_tiers_json: serde_json::to_string(&retrieval_metrics.tiers).ok(),
            retrieved_memories_json: serde_json::to_string(&serde_json::json!({
                "retrieval_count": retrieval_metrics.retrieval_count,
                "retrieval_hit": retrieval_metrics.retrieval_hit,
                "avg_similarity": retrieval_metrics.avg_similarity,
                "budget_utilization": retrieval_metrics.budget_utilization,
            }))
            .ok(),
        },
        retrieval_metrics,
        tool_search_stats,
        delegated_execution_result: input.delegated_execution_result.clone(),
    })
}

/// Wraps `content` in a `system`-role message with no multimodal parts.
fn system_note_message(content: String) -> roboticus_llm::format::UnifiedMessage {
    roboticus_llm::format::UnifiedMessage {
        role: "system".into(),
        content,
        parts: None,
    }
}

/// Converts stored session messages into LLM conversation history.
///
/// Messages tagged with `current_topic`, and untagged messages, are kept verbatim and in
/// order. Each contiguous run of messages tagged with any other topic is replaced by a
/// single `system` summary placed where the run occurred.
fn build_topic_aware_history(
    messages: &[roboticus_db::sessions::Message],
    current_topic: &str,
) -> Vec<roboticus_llm::format::UnifiedMessage> {
    let mut history = Vec::with_capacity(messages.len());
    let mut off_topic_run: Vec<&roboticus_db::sessions::Message> = Vec::new();
    for msg in messages {
        let on_topic = msg.topic_tag.as_deref().is_none_or(|t| t == current_topic);
        if on_topic {
            flush_off_topic_run(&mut off_topic_run, &mut history);
            history.push(roboticus_llm::format::UnifiedMessage {
                role: msg.role.clone(),
                content: msg.content.clone(),
                parts: None,
            });
        } else {
            off_topic_run.push(msg);
        }
    }
    // A trailing off-topic run has no on-topic message after it to trigger the flush.
    flush_off_topic_run(&mut off_topic_run, &mut history);
    history
}

/// Replaces an accumulated off-topic run with a single `system` summary on `history`,
/// then empties the run. Does nothing if the run is empty.
///
/// The summary is labelled with the first message's topic tag (or `"earlier"`). It is
/// seeded with the run's first user message (or `"(earlier conversation)"`).
fn flush_off_topic_run(
    run: &mut Vec<&roboticus_db::sessions::Message>,
    history: &mut Vec<roboticus_llm::format::UnifiedMessage>,
) {
    let Some(first) = run.first() else {
        return;
    };
    let tag = first.topic_tag.as_deref().unwrap_or("earlier");
    let first_user = run
        .iter()
        .find(|m| m.role == "user")
        .map(|m| m.content.as_str())
        .unwrap_or("(earlier conversation)");
    let summary = roboticus_agent::topic::summarize_topic_block(tag, run.len(), first_user);
    history.push(system_note_message(summary));
    run.clear();
}