use std::sync::Arc;
use nostr_sdk::prelude::{Client, Keys};
use rusqlite::params;
use tokio::sync::Mutex as AsyncMutex;
use tracing::{debug, info, instrument, warn};
use crate::db;
use crate::error::Result;
use crate::models::dispute::InitiatorRole;
use crate::models::mediation::{EscalationTrigger, MediationSessionState};
use crate::models::reasoning::{ClassificationRequest, ReasoningContext};
use crate::models::SolverConfig;
use crate::prompts::PromptBundle;
use crate::reasoning::ReasoningProvider;
use super::{
deliver_summary, draft_and_send_followup_message, escalation, notify_solvers_escalation,
policy, transcript, SessionKeyCache,
};
/// Cap passed to `transcript::load_transcript_for_session` when building the
/// transcript for a classification request (presumably the maximum number of
/// entries returned — confirm against that helper).
const TRANSCRIPT_CAP: usize = 40;
/// Consecutive-failure count (as returned by
/// `db::mediation::bump_consecutive_eval_failures`) at or above which
/// `handle_reasoning_failure` stops waiting for a retry and recommends
/// escalation with `EscalationTrigger::ReasoningUnavailable`.
const CONSECUTIVE_FAILURE_ESCALATION_THRESHOLD: i64 = 3;
#[instrument(skip_all, fields(session_id = %session_id))]
#[allow(clippy::too_many_arguments)]
pub async fn advance_session_round(
conn: &Arc<AsyncMutex<rusqlite::Connection>>,
client: &Client,
serbero_keys: &Keys,
reasoning: &dyn ReasoningProvider,
prompt_bundle: &Arc<PromptBundle>,
session_id: &str,
session_key_cache: &SessionKeyCache,
solvers: &[SolverConfig],
provider_name: &str,
model_name: &str,
) -> Result<()> {
let info = match load_session_info(conn, session_id).await? {
Some(i) => i,
None => {
debug!("advance_session_round: session row not found; skipping");
return Ok(());
}
};
if !matches!(info.state, MediationSessionState::AwaitingResponse) {
debug!(
state = %info.state,
"advance_session_round: session not in awaiting_response; skipping"
);
return Ok(());
}
let total_fresh_inbounds = {
let guard = conn.lock().await;
db::mediation::count_fresh_inbounds(&guard, session_id)?
};
if total_fresh_inbounds <= info.round_count_last_evaluated {
debug!(
total_fresh_inbounds,
round_count_last_evaluated = info.round_count_last_evaluated,
"advance_session_round: no new fresh inbounds since last evaluation; skipping"
);
return Ok(());
}
let material = {
let cache = session_key_cache.lock().await;
cache.get(session_id).cloned()
};
let Some(material) = material else {
debug!("advance_session_round: no chat material in cache (post-restart?); skipping");
return Ok(());
};
let initiator_role = match load_initiator_role(conn, &info.dispute_id).await? {
Some(r) => r,
None => {
warn!(
dispute_id = %info.dispute_id,
"advance_session_round: dispute row vanished; skipping"
);
return Ok(());
}
};
let transcript_entries = {
let guard = conn.lock().await;
transcript::load_transcript_for_session(&guard, session_id, TRANSCRIPT_CAP)?
};
let classification_req = ClassificationRequest {
session_id: session_id.to_string(),
dispute_id: info.dispute_id.clone(),
initiator_role,
prompt_bundle: Arc::clone(prompt_bundle),
transcript: transcript_entries.clone(),
context: ReasoningContext {
round_count: info.round_count.max(0) as u32,
last_classification: None,
last_confidence: None,
},
};
let classification = match reasoning.classify(classification_req).await {
Ok(c) => c,
Err(e) => {
warn!(error = %e, "advance_session_round: reasoning.classify failed");
handle_reasoning_failure(
conn,
client,
session_id,
&info.dispute_id,
solvers,
prompt_bundle,
)
.await;
return Ok(());
}
};
let followup_number = {
let guard = conn.lock().await;
db::mediation::count_classification_events(&guard, session_id)?
};
let decision = match policy::evaluate(
conn,
session_id,
prompt_bundle,
provider_name,
model_name,
classification,
followup_number,
)
.await
{
Ok(d) => d,
Err(e) => {
warn!(error = %e, "advance_session_round: policy::evaluate failed");
handle_reasoning_failure(
conn,
client,
session_id,
&info.dispute_id,
solvers,
prompt_bundle,
)
.await;
return Ok(());
}
};
match decision {
policy::PolicyDecision::AskClarification {
buyer_text,
seller_text,
} => {
let new_marker = total_fresh_inbounds;
let round_number = followup_number;
if let Err(e) = draft_and_send_followup_message(
conn,
client,
serbero_keys,
session_id,
round_number,
new_marker,
&material.buyer_shared_keys,
&material.seller_shared_keys,
prompt_bundle,
&buyer_text,
&seller_text,
)
.await
{
warn!(
error = %e,
"advance_session_round: follow-up drafter failed; rows may be committed without publish"
);
handle_reasoning_failure(
conn,
client,
session_id,
&info.dispute_id,
solvers,
prompt_bundle,
)
.await;
return Ok(());
}
info!(
round = round_number,
round_count_marked = new_marker,
"advance_session_round: AskClarification dispatched"
);
}
policy::PolicyDecision::Summarize {
classification,
confidence,
} => {
{
let guard = conn.lock().await;
db::mediation::set_session_state(
&guard,
session_id,
MediationSessionState::Classified,
super::current_ts_secs()?,
)?;
}
if let Err(e) = deliver_summary(
conn,
client,
serbero_keys,
session_id,
&info.dispute_id,
classification,
confidence,
transcript_entries,
prompt_bundle,
reasoning,
solvers,
provider_name,
model_name,
)
.await
{
warn!(error = %e, "advance_session_round: deliver_summary failed");
handle_reasoning_failure(
conn,
client,
session_id,
&info.dispute_id,
solvers,
prompt_bundle,
)
.await;
return Ok(());
}
let new_marker = total_fresh_inbounds;
let mut guard = conn.lock().await;
let tx = guard.transaction()?;
db::mediation::advance_evaluator_marker(&tx, session_id, new_marker)?;
tx.commit()?;
info!(
round_count_marked = new_marker,
"advance_session_round: Summarize dispatched"
);
}
policy::PolicyDecision::Escalate(trigger) => {
if let Err(e) = escalation::recommend(escalation::RecommendParams {
conn,
session_id: Some(session_id),
dispute_id: &info.dispute_id,
trigger,
evidence_refs: Vec::new(),
rationale_refs: Vec::new(),
prompt_bundle_id: &prompt_bundle.id,
policy_hash: &prompt_bundle.policy_hash,
})
.await
{
warn!(
error = %e,
trigger = %trigger,
"advance_session_round: escalation::recommend failed"
);
handle_reasoning_failure(
conn,
client,
session_id,
&info.dispute_id,
solvers,
prompt_bundle,
)
.await;
return Ok(());
}
notify_solvers_escalation(conn, client, solvers, &info.dispute_id, session_id, trigger)
.await;
info!(
trigger = %trigger,
"advance_session_round: Escalate dispatched"
);
}
}
Ok(())
}
/// The subset of a `mediation_sessions` row that `advance_session_round`
/// needs in order to decide whether to run an evaluation round.
struct SessionEvalInfo {
    /// Owning dispute (`mediation_sessions.dispute_id`).
    dispute_id: String,
    /// Lifecycle state parsed from `mediation_sessions.state`.
    state: MediationSessionState,
    /// `mediation_sessions.round_count`, forwarded into the reasoning context.
    round_count: i64,
    /// `mediation_sessions.round_count_last_evaluated`; compared against the
    /// current fresh-inbound count to decide whether there is anything new.
    round_count_last_evaluated: i64,
}
async fn load_session_info(
conn: &Arc<AsyncMutex<rusqlite::Connection>>,
session_id: &str,
) -> Result<Option<SessionEvalInfo>> {
use std::str::FromStr;
let guard = conn.lock().await;
let row = guard.query_row(
"SELECT state, round_count, round_count_last_evaluated, dispute_id
FROM mediation_sessions
WHERE session_id = ?1",
params![session_id],
|r| {
Ok((
r.get::<_, String>(0)?,
r.get::<_, i64>(1)?,
r.get::<_, i64>(2)?,
r.get::<_, String>(3)?,
))
},
);
match row {
Ok((state_s, round_count, rcle, dispute_id)) => {
let state = MediationSessionState::from_str(&state_s)?;
Ok(Some(SessionEvalInfo {
state,
round_count,
round_count_last_evaluated: rcle,
dispute_id,
}))
}
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e.into()),
}
}
async fn load_initiator_role(
conn: &Arc<AsyncMutex<rusqlite::Connection>>,
dispute_id: &str,
) -> Result<Option<InitiatorRole>> {
use std::str::FromStr;
let guard = conn.lock().await;
let s: Option<String> = match guard.query_row(
"SELECT initiator_role FROM disputes WHERE dispute_id = ?1",
params![dispute_id],
|r| r.get::<_, String>(0),
) {
Ok(s) => Some(s),
Err(rusqlite::Error::QueryReturnedNoRows) => None,
Err(e) => return Err(e.into()),
};
match s {
Some(s) => Ok(Some(InitiatorRole::from_str(&s)?)),
None => Ok(None),
}
}
/// Records one failed evaluation attempt for `session_id` and escalates once
/// the consecutive-failure count reaches
/// [`CONSECUTIVE_FAILURE_ESCALATION_THRESHOLD`].
///
/// Below the threshold, the failure is only logged and the function returns.
/// At or above it, an escalation is recommended with
/// [`EscalationTrigger::ReasoningUnavailable`]. If that recommendation
/// succeeds, the configured solvers are notified.
///
/// Best-effort: returns `()` and only logs its own failures (bumping the
/// counter or recording the escalation).
async fn handle_reasoning_failure(
    conn: &Arc<AsyncMutex<rusqlite::Connection>>,
    client: &Client,
    session_id: &str,
    dispute_id: &str,
    solvers: &[SolverConfig],
    prompt_bundle: &Arc<PromptBundle>,
) {
    let bump_outcome = {
        let guard = conn.lock().await;
        db::mediation::bump_consecutive_eval_failures(&guard, session_id)
    };
    let failures = match bump_outcome {
        Ok(count) => count,
        Err(e) => {
            warn!(error = %e, "advance_session_round: failed to bump failure counter");
            return;
        }
    };

    let below_threshold = failures < CONSECUTIVE_FAILURE_ESCALATION_THRESHOLD;
    if below_threshold {
        warn!(
            failures,
            threshold = CONSECUTIVE_FAILURE_ESCALATION_THRESHOLD,
            "advance_session_round: will retry on next tick"
        );
        return;
    }

    warn!(
        failures,
        "advance_session_round: consecutive failure threshold reached; escalating"
    );

    let trigger = EscalationTrigger::ReasoningUnavailable;
    let params = escalation::RecommendParams {
        conn,
        session_id: Some(session_id),
        dispute_id,
        trigger,
        evidence_refs: Vec::new(),
        rationale_refs: Vec::new(),
        prompt_bundle_id: &prompt_bundle.id,
        policy_hash: &prompt_bundle.policy_hash,
    };
    match escalation::recommend(params).await {
        Err(e) => {
            warn!(
                error = %e,
                "advance_session_round: escalation::recommend also failed after reasoning failures"
            );
        }
        Ok(_) => {
            // Solvers are only told once the escalation has been recorded.
            notify_solvers_escalation(conn, client, solvers, dispute_id, session_id, trigger)
                .await;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::migrations::run_migrations;
    use crate::db::open_in_memory;
    use crate::mediation::auth_retry::AuthRetryHandle;
    use crate::models::mediation::TranscriptParty;
    use crate::models::reasoning::{
        ClassificationResponse, ReasoningError, SummaryRequest, SummaryResponse,
    };
    use crate::prompts::PromptBundle;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal prompt bundle whose id/hash match the seeded rows.
    fn test_bundle() -> Arc<PromptBundle> {
        Arc::new(PromptBundle {
            id: "phase3-default".into(),
            policy_hash: "hash-test".into(),
            system: String::new(),
            classification: String::new(),
            escalation: String::new(),
            mediation_style: String::new(),
            message_templates: String::new(),
        })
    }

    /// Reasoning provider that counts `classify` calls; every gate test
    /// asserts it is never reached.
    struct SpyClassifier {
        calls: AtomicUsize,
    }

    #[async_trait]
    impl ReasoningProvider for SpyClassifier {
        async fn classify(
            &self,
            _request: ClassificationRequest,
        ) -> std::result::Result<ClassificationResponse, ReasoningError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Err(ReasoningError::Unreachable("should not be called".into()))
        }
        async fn summarize(
            &self,
            _request: SummaryRequest,
        ) -> std::result::Result<SummaryResponse, ReasoningError> {
            panic!("summarize unused in follow_up tests")
        }
        async fn health_check(&self) -> std::result::Result<(), ReasoningError> {
            Ok(())
        }
    }

    /// In-memory DB with migrations applied and one dispute (`d-t120`).
    async fn seeded_db() -> Arc<AsyncMutex<rusqlite::Connection>> {
        let mut conn = open_in_memory().unwrap();
        run_migrations(&mut conn).unwrap();
        conn.execute(
            "INSERT INTO disputes (
                dispute_id, event_id, mostro_pubkey, initiator_role,
                dispute_status, event_timestamp, detected_at, lifecycle_state
            ) VALUES ('d-t120', 'e-t120', 'm', 'buyer',
                'initiated', 1, 2, 'notified')",
            [],
        )
        .unwrap();
        Arc::new(AsyncMutex::new(conn))
    }

    /// Inserts session `sess-t120` for dispute `d-t120` with the given state
    /// and round counters.
    async fn seed_session(
        conn: &Arc<AsyncMutex<rusqlite::Connection>>,
        state: &str,
        round_count: i64,
        round_count_last_evaluated: i64,
    ) {
        let guard = conn.lock().await;
        guard
            .execute(
                "INSERT INTO mediation_sessions (
                    session_id, dispute_id, state, round_count,
                    round_count_last_evaluated, consecutive_eval_failures,
                    prompt_bundle_id, policy_hash,
                    started_at, last_transition_at
                ) VALUES ('sess-t120', 'd-t120', ?1, ?2, ?3, 0,
                    'phase3-default', 'hash-test',
                    100, 100)",
                params![state, round_count, round_count_last_evaluated],
            )
            .unwrap();
    }

    /// Inserts one non-stale inbound message for `sess-t120`. The inner event
    /// id is derived from `inner_event_created_at`, so use distinct timestamps.
    async fn seed_fresh_inbound(
        conn: &Arc<AsyncMutex<rusqlite::Connection>>,
        party: TranscriptParty,
        inner_event_created_at: i64,
    ) {
        let guard = conn.lock().await;
        let party_s = match party {
            TranscriptParty::Buyer => "buyer",
            TranscriptParty::Seller => "seller",
            TranscriptParty::Serbero => {
                panic!("Serbero is outbound-only; not valid for inbound seed")
            }
        };
        guard
            .execute(
                "INSERT INTO mediation_messages (
                    session_id, direction, party, shared_pubkey,
                    inner_event_id, inner_event_created_at,
                    outer_event_id, content,
                    prompt_bundle_id, policy_hash,
                    persisted_at, stale
                ) VALUES ('sess-t120', 'inbound', ?1, 'sp-test',
                    ?2, ?3, NULL, 'hello',
                    'phase3-default', 'hash-test',
                    200, 0)",
                params![
                    party_s,
                    format!("inner-{}", inner_event_created_at),
                    inner_event_created_at
                ],
            )
            .unwrap();
    }

    /// Seeds three fresh inbounds. With `round_count_last_evaluated = 2` this
    /// gets past the fresh-inbound gate, so a later gate must do the skipping.
    async fn seed_three_fresh_inbounds(conn: &Arc<AsyncMutex<rusqlite::Connection>>) {
        seed_fresh_inbound(conn, TranscriptParty::Buyer, 10).await;
        seed_fresh_inbound(conn, TranscriptParty::Seller, 20).await;
        seed_fresh_inbound(conn, TranscriptParty::Buyer, 30).await;
    }

    /// Runs `advance_session_round` once for `sess-t120` with an empty key
    /// cache and no solvers.
    async fn run_once(
        conn: &Arc<AsyncMutex<rusqlite::Connection>>,
        reasoning: &dyn ReasoningProvider,
    ) {
        let serbero_keys = Keys::generate();
        let client = Client::new(serbero_keys.clone());
        let bundle = test_bundle();
        let cache: SessionKeyCache = Arc::new(AsyncMutex::new(HashMap::new()));
        let _auth = AuthRetryHandle::new_authorized();
        advance_session_round(
            conn,
            &client,
            &serbero_keys,
            reasoning,
            &bundle,
            "sess-t120",
            &cache,
            &[],
            "mock-provider",
            "mock-model",
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn skips_when_session_row_is_absent() {
        let mut conn = open_in_memory().unwrap();
        run_migrations(&mut conn).unwrap();
        let conn = Arc::new(AsyncMutex::new(conn));
        let spy = SpyClassifier {
            calls: AtomicUsize::new(0),
        };
        run_once(&conn, &spy).await;
        assert_eq!(spy.calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn skips_when_state_is_not_awaiting_response() {
        let conn = seeded_db().await;
        seed_session(&conn, "escalation_recommended", 3, 2).await;
        // Enough fresh inbounds that the fresh-inbound gate would pass, so
        // only the state gate can be what blocks classify.
        seed_three_fresh_inbounds(&conn).await;
        let spy = SpyClassifier {
            calls: AtomicUsize::new(0),
        };
        run_once(&conn, &spy).await;
        assert_eq!(
            spy.calls.load(Ordering::SeqCst),
            0,
            "state gate must block classify when session is not awaiting_response"
        );
    }

    #[tokio::test]
    async fn skips_when_fresh_inbounds_already_evaluated() {
        let conn = seeded_db().await;
        seed_session(&conn, "awaiting_response", 1, 2).await;
        seed_fresh_inbound(&conn, TranscriptParty::Buyer, 10).await;
        seed_fresh_inbound(&conn, TranscriptParty::Seller, 20).await;
        let spy = SpyClassifier {
            calls: AtomicUsize::new(0),
        };
        run_once(&conn, &spy).await;
        assert_eq!(
            spy.calls.load(Ordering::SeqCst),
            0,
            "gate must block when total fresh inbounds <= round_count_last_evaluated"
        );
    }

    #[tokio::test]
    async fn skips_when_cache_material_missing() {
        let conn = seeded_db().await;
        seed_session(&conn, "awaiting_response", 3, 2).await;
        // Without these, 0 fresh inbounds <= 2 would trip the fresh-inbound
        // gate first and the cache gate would never be exercised.
        seed_three_fresh_inbounds(&conn).await;
        let spy = SpyClassifier {
            calls: AtomicUsize::new(0),
        };
        run_once(&conn, &spy).await;
        assert_eq!(
            spy.calls.load(Ordering::SeqCst),
            0,
            "missing-cache gate must block classify when material is absent"
        );
    }
}