use std::time::Duration;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, warn};
use super::ReasoningProvider;
use crate::error::Result;
use crate::models::mediation::{ClassificationLabel, Flag};
use crate::models::reasoning::{
ClassificationRequest, ClassificationResponse, EscalationReason, RationaleText, ReasoningError,
SuggestedAction, SummaryRequest, SummaryResponse,
};
use crate::models::ReasoningConfig;
/// Reasoning provider backed by an OpenAI-compatible `/chat/completions` endpoint.
///
/// Construct it with `OpenAiProvider::new`; all fields are derived from the
/// `ReasoningConfig` passed there.
pub struct OpenAiProvider {
    /// HTTP client built in `new` with `timeout` as its default timeout.
    http: Client,
    /// Base URL with trailing `/` characters removed, so joining a path never yields `//`.
    api_base: String,
    /// Sent as a bearer token on every request; not included in the tracing
    /// fields emitted by this module.
    api_key: String,
    /// Model identifier sent in every request body.
    model: String,
    /// Per-request timeout (config value floored to 1 s).
    timeout: Duration,
    /// Extra attempts after the first one; `post_chat` makes `retries + 1` attempts (saturating).
    retries: u32,
}
impl OpenAiProvider {
    /// Builds a provider from `config`.
    ///
    /// - `request_timeout_seconds` is used as the HTTP client's default timeout
    ///   (and as the per-request timeout in `post_chat`). A value of `0` is
    ///   floored to 1 s so a misconfiguration cannot disable the timeout.
    /// - Trailing `/` characters on `api_base` are stripped so
    ///   `chat_completions_url` never produces a double slash.
    /// - `followup_retry_count` is the number of extra attempts after the first.
    ///
    /// # Errors
    /// Returns `Error::Config` if the reqwest client cannot be built.
    pub fn new(config: &ReasoningConfig) -> Result<Self> {
        let timeout = Duration::from_secs(config.request_timeout_seconds.max(1));
        let http = Client::builder()
            .timeout(timeout)
            .build()
            .map_err(|e| crate::error::Error::Config(format!("reqwest build failed: {e}")))?;
        Ok(Self {
            http,
            api_base: config.api_base.trim_end_matches('/').to_string(),
            api_key: config.api_key.clone(),
            model: config.model.clone(),
            timeout,
            retries: config.followup_retry_count,
        })
    }

    /// Full URL of the chat-completions endpoint under the configured base.
    fn chat_completions_url(&self) -> String {
        format!("{}/chat/completions", self.api_base)
    }

    /// Returns the `temperature` to send, or `None` to omit the field entirely.
    ///
    /// The GPT-5 family and OpenAI's o-series reasoning models (`o1*`, `o3*`,
    /// `o4*`) only accept their default temperature and reject requests that
    /// set one, so the field is dropped for them. Omitting it is always
    /// accepted, which makes this check safe to over-match.
    fn request_temperature(&self, value: f64) -> Option<f64> {
        // OpenAI-compatible gateways often use provider-qualified names such as
        // `openai/gpt-5`; match on the bare model name after the last `/`.
        let bare_model = self
            .model
            .rsplit_once('/')
            .map_or(self.model.as_str(), |(_, name)| name);
        let temperature_is_fixed = bare_model.starts_with("gpt-5")
            || ["o1", "o3", "o4"]
                .iter()
                .any(|prefix| bare_model.starts_with(prefix));
        if temperature_is_fixed {
            None
        } else {
            Some(value)
        }
    }
}
/// One element of the chat `messages` array.
#[derive(Serialize)]
struct ChatMessage<'a> {
    /// Message role; this module only sends `"system"` and `"user"`.
    role: &'a str,
    /// Message text.
    content: &'a str,
}
/// Request body for `POST {api_base}/chat/completions`.
///
/// Borrows the model name and message texts from the caller, so building a
/// request does not copy prompt strings.
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<ChatMessage<'a>>,
    /// Omitted from the JSON when `None`. `classify` sets it to `json_object`.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    /// Omitted from the JSON when `None` (see `OpenAiProvider::request_temperature`).
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
}
/// `response_format` object, serialized as `{"type": <kind>}`.
#[derive(Serialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    kind: String,
}
/// Minimal view of a chat-completions response: only the fields this module reads.
/// Other JSON fields in the response are ignored during deserialization.
#[derive(Deserialize)]
struct ChatResponse {
    /// Only the first choice is used by `post_chat`.
    choices: Vec<Choice>,
}
/// A single completion choice.
#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}
/// Assistant message of a choice.
#[derive(Deserialize)]
struct ResponseMessage {
    /// May be absent or `null`; `post_chat` treats a missing content as a
    /// malformed response.
    content: Option<String>,
}
/// Raw JSON shape the classification prompt asks the model to return.
///
/// All string-valued enums are kept as `String` here and validated in
/// `parse_classification`. Only `classification` and `confidence` are
/// required by serde; the rest fall back to their defaults when absent.
#[derive(Deserialize)]
struct ClassificationJson {
    classification: String,
    confidence: f64,
    /// Defaults to `""`, which `parse_classification` rejects as an unknown action.
    #[serde(default)]
    suggested_action: String,
    /// Used only as the escalation reason when `suggested_action == "escalate"`.
    #[serde(default)]
    suggested_action_detail: Option<String>,
    /// Required (non-blank) when `suggested_action == "ask_clarification"`.
    #[serde(default)]
    buyer_clarification: Option<String>,
    /// Required (non-blank) when `suggested_action == "ask_clarification"`.
    #[serde(default)]
    seller_clarification: Option<String>,
    #[serde(default)]
    rationale: String,
    /// Flag names; any name not recognised by `parse_classification` is an error.
    #[serde(default)]
    flags: Vec<String>,
}
#[async_trait]
impl ReasoningProvider for OpenAiProvider {
async fn classify(
&self,
request: ClassificationRequest,
) -> std::result::Result<ClassificationResponse, ReasoningError> {
let system = request.prompt_bundle.system.clone();
let prompt = build_classification_prompt(&request);
let body = ChatRequest {
model: &self.model,
messages: vec![
ChatMessage {
role: "system",
content: &system,
},
ChatMessage {
role: "user",
content: &prompt,
},
],
response_format: Some(ResponseFormat {
kind: "json_object".into(),
}),
temperature: self.request_temperature(0.0),
};
let raw = self.post_chat(&body).await?;
parse_classification(&raw)
}
async fn summarize(
&self,
request: SummaryRequest,
) -> std::result::Result<SummaryResponse, ReasoningError> {
let system = request.prompt_bundle.system.clone();
let prompt = build_summary_prompt(&request);
let body = ChatRequest {
model: &self.model,
messages: vec![
ChatMessage {
role: "system",
content: &system,
},
ChatMessage {
role: "user",
content: &prompt,
},
],
response_format: None,
temperature: self.request_temperature(0.2),
};
let raw = self.post_chat(&body).await?;
parse_summary(&raw)
}
async fn health_check(&self) -> std::result::Result<(), ReasoningError> {
let body = ChatRequest {
model: &self.model,
messages: vec![ChatMessage {
role: "user",
content: "ping",
}],
response_format: None,
temperature: self.request_temperature(0.0),
};
self.post_chat(&body).await.map(|_| ())
}
}
impl OpenAiProvider {
    /// POSTs `body` to the chat-completions endpoint and returns the text
    /// content of the first choice.
    ///
    /// Makes up to `retries + 1` attempts (saturating). The following are
    /// retried:
    /// - transport errors from `send`;
    /// - timeouts, including a timeout while reading the response body (the
    ///   client/request timeout covers the whole exchange);
    /// - HTTP 408, 429 and 5xx.
    ///
    /// Any other non-success status fails immediately.
    ///
    /// NOTE(review): attempts are issued back-to-back with no delay, and a
    /// `Retry-After` header is not honoured, so a 429 is likely to be hit again
    /// on the next attempt. A backoff needs an async timer (e.g. tokio), which
    /// this module does not currently import.
    ///
    /// # Errors
    /// - The last `Timeout` / `Unreachable` error once all attempts fail.
    /// - `Unreachable` with the status and up to 200 bytes of the body for a
    ///   non-retryable HTTP status.
    /// - `MalformedResponse` (not retried) if the body cannot be read for a
    ///   non-timeout reason, is not chat-completions JSON, or the first choice
    ///   has no content.
    async fn post_chat(
        &self,
        body: &ChatRequest<'_>,
    ) -> std::result::Result<String, ReasoningError> {
        let url = self.chat_completions_url();
        let mut last_err: Option<ReasoningError> = None;
        let total_attempts = self.retries.saturating_add(1);
        for attempt in 0..total_attempts {
            debug!(
                attempt,
                api_base = self.api_base,
                model = self.model,
                "openai reasoning call"
            );
            let resp = self
                .http
                .post(&url)
                .bearer_auth(&self.api_key)
                .json(body)
                .timeout(self.timeout)
                .send()
                .await;
            let resp = match resp {
                Ok(r) => r,
                Err(e) if e.is_timeout() => {
                    last_err = Some(ReasoningError::Timeout);
                    warn!(attempt, "openai request timed out");
                    continue;
                }
                Err(e) => {
                    last_err = Some(ReasoningError::Unreachable(e.to_string()));
                    warn!(attempt, error = %e, "openai request failed");
                    continue;
                }
            };
            if !resp.status().is_success() {
                let status = resp.status();
                let body = resp.text().await.unwrap_or_default();
                let err =
                    ReasoningError::Unreachable(format!("http {status}: {}", truncate(&body, 200)));
                // Request timeout, rate limiting and server-side failures are
                // transient; other client errors will not succeed on retry.
                let retryable =
                    status.as_u16() == 408 || status.as_u16() == 429 || status.is_server_error();
                if retryable {
                    last_err = Some(err);
                    warn!(attempt, %status, "openai returned retryable status");
                    continue;
                } else {
                    error!(%status, "openai returned non-retryable status; failing fast");
                    return Err(err);
                }
            }
            let text = match resp.text().await {
                Ok(text) => text,
                // The timeout also bounds reading the body, so a slow body shows up
                // here. It is a timeout, not a malformed response: classify it
                // the same way as a `send()` timeout and retry.
                Err(e) if e.is_timeout() => {
                    last_err = Some(ReasoningError::Timeout);
                    warn!(attempt, "openai response body timed out");
                    continue;
                }
                Err(e) => return Err(ReasoningError::MalformedResponse(e.to_string())),
            };
            let parsed: ChatResponse = serde_json::from_str(&text).map_err(|e| {
                ReasoningError::MalformedResponse(format!("{e}: body={}", truncate(&text, 200)))
            })?;
            let content = parsed
                .choices
                .into_iter()
                .next()
                .and_then(|c| c.message.content)
                .ok_or_else(|| ReasoningError::MalformedResponse("empty choices".into()))?;
            // Log only a length and a hash prefix of the model output, not the
            // text itself.
            use nostr_sdk::hashes::Hash as _;
            let content_hash = nostr_sdk::hashes::sha256::Hash::hash(content.as_bytes());
            let content_hash_prefix = &content_hash.to_string()[..16];
            debug!(
                attempt,
                model = self.model,
                content_len = content.len(),
                content_sha256_prefix = content_hash_prefix,
                "openai reasoning call response"
            );
            return Ok(content);
        }
        Err(last_err.unwrap_or_else(|| ReasoningError::Unreachable("exhausted retries".into())))
    }
}
/// Renders one transcript entry as a single prompt line: `[created_at] party: content`.
///
/// `content` is party-authored and untrusted, so it is escaped to stay on one
/// line:
/// - backslashes are doubled;
/// - `\n` and `\r` become the two-character sequences `\n` / `\r`;
/// - other line-breaking characters (VT, FF, NEL, U+2028, U+2029) become
///   `\u{..}` escapes.
///
/// Without this, a party could embed a newline followed by
/// `[<ts>] <other party>: ...`. That forges statements attributed to the
/// counterparty, or opens a fake `## ...` section of the prompt.
fn render_transcript_entry(
    created_at: impl std::fmt::Display,
    party: impl std::fmt::Display,
    content: &str,
) -> String {
    let mut escaped = String::with_capacity(content.len());
    for ch in content.chars() {
        match ch {
            // Escape the escape character first so the encoding stays unambiguous.
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\u{0B}' | '\u{0C}' | '\u{85}' | '\u{2028}' | '\u{2029}' => {
                escaped.extend(ch.escape_unicode())
            }
            other => escaped.push(other),
        }
    }
    format!("[{created_at}] {party}: {escaped}")
}

/// Builds the user prompt for the classification call.
///
/// Includes session metadata, every bundle section (classification,
/// escalation, mediation style, message templates), the transcript (one
/// escaped line per entry), and the JSON output contract that
/// `parse_classification` validates.
fn build_classification_prompt(r: &ClassificationRequest) -> String {
    let transcript = r
        .transcript
        .iter()
        .map(|e| {
            render_transcript_entry(&e.inner_event_created_at, &e.party, &e.content.to_string())
        })
        .collect::<Vec<_>>()
        .join("\n");
    format!(
        "## Session metadata\n\
         session_id: {sid}\n\
         dispute_id: {did}\n\
         initiator: {init}\n\
         prompt_bundle_id: {bid}\n\
         policy_hash: {ph}\n\
         round_count: {rc}\n\n\
         ## Classification policy (from bundle)\n{cls}\n\n\
         ## Escalation policy (from bundle)\n{esc}\n\n\
         ## Mediation style (from bundle)\n{sty}\n\n\
         ## Message templates (from bundle)\n{tpl}\n\n\
         ## Transcript\n{tr}\n\n\
         ## Output contract\n\
         Return JSON with keys: classification (one of coordination_failure_resolvable, \
         conflicting_claims, suspected_fraud, unclear, not_suitable_for_mediation), \
         confidence (0..1), suggested_action (ask_clarification|summarize|escalate), \
         rationale (string), flags (array of fraud_risk|conflicting_claims|low_info|\
         unresponsive_party|authority_boundary_attempt).\n\
         When suggested_action = ask_clarification you MUST also return \
         buyer_clarification (string, addressed to the buyer, asking what you need \
         from the buyer to advance the case) and seller_clarification (string, \
         addressed to the seller, asking what you need from the seller). Each text \
         goes only to its intended party — do NOT prefix with labels like \
         \"Buyer:\" or \"Seller:\"; the transport layer handles routing. Tailor \
         each question to that party's role (buyer = did you send fiat? proof; \
         seller = did you receive fiat? proof). Both strings must be non-empty; \
         if you cannot produce a useful question for one side, pick a different \
         suggested_action (summarize or escalate). suggested_action_detail is \
         optional and only used to carry the escalation reason when \
         suggested_action = escalate.",
        sid = r.session_id,
        did = r.dispute_id,
        init = r.initiator_role,
        bid = r.prompt_bundle.id,
        ph = r.prompt_bundle.policy_hash,
        rc = r.context.round_count,
        cls = r.prompt_bundle.classification,
        esc = r.prompt_bundle.escalation,
        sty = r.prompt_bundle.mediation_style,
        tpl = r.prompt_bundle.message_templates,
        tr = transcript,
    )
}

/// Builds the user prompt for the summary call.
///
/// Includes session metadata with the prior classification and confidence,
/// the mediation-style, message-template and escalation bundle sections, the
/// transcript (one escaped line per entry), and the
/// `SUGGESTED_NEXT_STEP:` / `RATIONALE:` output contract that `parse_summary`
/// expects.
fn build_summary_prompt(r: &SummaryRequest) -> String {
    let transcript = r
        .transcript
        .iter()
        .map(|e| {
            render_transcript_entry(&e.inner_event_created_at, &e.party, &e.content.to_string())
        })
        .collect::<Vec<_>>()
        .join("\n");
    format!(
        "## Session metadata\n\
         session_id: {sid}\n\
         dispute_id: {did}\n\
         prompt_bundle_id: {bid}\n\
         policy_hash: {ph}\n\
         classification: {cls}\n\
         confidence: {cf}\n\n\
         ## Mediation style (from bundle)\n{sty}\n\n\
         ## Message templates (from bundle)\n{tpl}\n\n\
         ## Escalation policy (from bundle, for reference)\n{esc}\n\n\
         ## Transcript\n{tr}\n\n\
         ## Output contract\n\
         Produce a short summary for the assigned solver, followed by a single-line \
         SUGGESTED_NEXT_STEP: line. Do NOT suggest fund actions. Do NOT claim final \
         authority. End with a RATIONALE: line.",
        sid = r.session_id,
        did = r.dispute_id,
        bid = r.prompt_bundle.id,
        ph = r.prompt_bundle.policy_hash,
        cls = r.classification,
        cf = r.confidence,
        sty = r.prompt_bundle.mediation_style,
        tpl = r.prompt_bundle.message_templates,
        esc = r.prompt_bundle.escalation,
        tr = transcript,
    )
}
/// Parses the model's JSON classification reply into a `ClassificationResponse`.
///
/// Validation is strict:
/// - an unknown `classification` label, unknown `suggested_action`, or
///   unknown flag is rejected rather than mapped to a fallback;
/// - `ask_clarification` requires both `buyer_clarification` and
///   `seller_clarification` to be non-empty after trimming, and the trimmed
///   text is what is stored;
/// - `confidence` is clamped into `0.0..=1.0`;
/// - a missing escalation detail becomes an empty `EscalationReason`.
///
/// # Errors
/// `ReasoningError::MalformedResponse` for invalid JSON (the message includes
/// up to 200 bytes of the body) and for each validation failure above. Checks
/// run in the order label → action → flags, and the first failure is returned.
fn parse_classification(raw: &str) -> std::result::Result<ClassificationResponse, ReasoningError> {
    let reply: ClassificationJson = serde_json::from_str(raw).map_err(|e| {
        ReasoningError::MalformedResponse(format!("{e}: body={}", truncate(raw, 200)))
    })?;

    // Trimmed, non-empty question text for one side of an `ask_clarification`.
    let clarification = |text: Option<&str>, field: &str| {
        text.map(str::trim)
            .filter(|t| !t.is_empty())
            .map(str::to_string)
            .ok_or_else(|| {
                ReasoningError::MalformedResponse(format!(
                    "ask_clarification requires non-empty {field}"
                ))
            })
    };

    let classification = match reply.classification.as_str() {
        "coordination_failure_resolvable" => ClassificationLabel::CoordinationFailureResolvable,
        "conflicting_claims" => ClassificationLabel::ConflictingClaims,
        "suspected_fraud" => ClassificationLabel::SuspectedFraud,
        "unclear" => ClassificationLabel::Unclear,
        "not_suitable_for_mediation" => ClassificationLabel::NotSuitableForMediation,
        unknown => {
            return Err(ReasoningError::MalformedResponse(format!(
                "unknown classification label: {unknown}"
            )))
        }
    };

    let suggested_action = match reply.suggested_action.as_str() {
        // Field initialisers evaluate in order, so the buyer side is checked first.
        "ask_clarification" => SuggestedAction::AskClarification {
            buyer_text: clarification(reply.buyer_clarification.as_deref(), "buyer_clarification")?,
            seller_text: clarification(
                reply.seller_clarification.as_deref(),
                "seller_clarification",
            )?,
        },
        "summarize" => SuggestedAction::Summarize,
        "escalate" => SuggestedAction::Escalate(EscalationReason(
            reply
                .suggested_action_detail
                .as_deref()
                .unwrap_or_default()
                .to_string(),
        )),
        unknown => {
            return Err(ReasoningError::MalformedResponse(format!(
                "unknown suggested_action: {unknown}"
            )))
        }
    };

    let mut flags = Vec::with_capacity(reply.flags.len());
    for name in &reply.flags {
        let flag = match name.as_str() {
            "fraud_risk" => Flag::FraudRisk,
            "conflicting_claims" => Flag::ConflictingClaims,
            "low_info" => Flag::LowInfo,
            "unresponsive_party" => Flag::UnresponsiveParty,
            "authority_boundary_attempt" => Flag::AuthorityBoundaryAttempt,
            unknown => {
                return Err(ReasoningError::MalformedResponse(format!(
                    "unknown flag: {unknown}"
                )))
            }
        };
        flags.push(flag);
    }

    Ok(ClassificationResponse {
        classification,
        confidence: reply.confidence.clamp(0.0, 1.0),
        suggested_action,
        rationale: RationaleText(reply.rationale),
        flags,
    })
}
/// Splits a free-text summary reply into body, suggested next step and rationale.
///
/// The body is everything before the first marker found (or the whole reply
/// when neither marker appears). `SUGGESTED_NEXT_STEP:` runs up to
/// `RATIONALE:` or the end of the text, and `RATIONALE:` runs to the end.
/// A missing marker yields an empty string for that part. Only the first
/// occurrence of each marker is considered, and every part is trimmed.
///
/// # Errors
/// `ReasoningError::MalformedResponse` if `RATIONALE:` precedes
/// `SUGGESTED_NEXT_STEP:`, or if the summary body is empty after trimming.
fn parse_summary(raw: &str) -> std::result::Result<SummaryResponse, ReasoningError> {
    const NEXT_MARKER: &str = "SUGGESTED_NEXT_STEP:";
    const RATIONALE_MARKER: &str = "RATIONALE:";

    let next_pos = raw.find(NEXT_MARKER);
    let rationale_pos = raw.find(RATIONALE_MARKER);

    if matches!((next_pos, rationale_pos), (Some(n), Some(r)) if r < n) {
        return Err(ReasoningError::MalformedResponse(
            "summary markers out of order: RATIONALE: appeared before SUGGESTED_NEXT_STEP:".into(),
        ));
    }

    // Past the ordering check, the earliest marker (if any) ends the body.
    // The markers are ASCII, so every offset below is a char boundary.
    let body_end = next_pos.or(rationale_pos).unwrap_or(raw.len());
    let summary_text = raw[..body_end].trim().to_string();

    let suggested_next_step = next_pos
        .map(|n| {
            let stop = rationale_pos.unwrap_or(raw.len());
            raw[n + NEXT_MARKER.len()..stop].trim().to_string()
        })
        .unwrap_or_default();

    let rationale_text = rationale_pos
        .map(|r| raw[r + RATIONALE_MARKER.len()..].trim().to_string())
        .unwrap_or_default();

    if summary_text.is_empty() {
        return Err(ReasoningError::MalformedResponse(
            "empty summary body".into(),
        ));
    }

    Ok(SummaryResponse {
        summary_text,
        suggested_next_step,
        rationale: RationaleText(rationale_text),
    })
}
/// Returns the longest prefix of `s` that is at most `n` bytes long and ends on
/// a UTF-8 character boundary, so slicing never panics on multi-byte chars.
fn truncate(s: &str, n: usize) -> &str {
    if n >= s.len() {
        return s;
    }
    // Here `n < s.len()`. Walk back to the nearest char boundary; index 0 is
    // always a boundary, so the loop terminates.
    let mut cut = n;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    &s[..cut]
}
// Unit tests for the OpenAI adapter. They cover the pure helpers (JSON and
// summary parsing, UTF-8-safe truncation, prompt assembly) and constructor
// wiring only; none of them sends an HTTP request.
#[cfg(test)]
mod tests {
    use super::*;
    // A fully valid reply maps every field onto the typed response.
    #[test]
    fn parse_classification_happy_path() {
        let raw = r#"{
"classification":"coordination_failure_resolvable",
"confidence":0.91,
"suggested_action":"summarize",
"rationale":"parties agreed on payment timing",
"flags":["low_info"]
}"#;
        let parsed = parse_classification(raw).unwrap();
        assert_eq!(
            parsed.classification,
            ClassificationLabel::CoordinationFailureResolvable
        );
        assert!((parsed.confidence - 0.91).abs() < f64::EPSILON);
        assert_eq!(parsed.suggested_action, SuggestedAction::Summarize);
        assert_eq!(parsed.flags, vec![Flag::LowInfo]);
    }
    // ask_clarification carries one question per party, routed separately.
    #[test]
    fn parse_classification_ask_clarification_happy_path() {
        let raw = r#"{
"classification":"unclear",
"confidence":0.42,
"suggested_action":"ask_clarification",
"buyer_clarification":"Buyer, have you sent the fiat?",
"seller_clarification":"Seller, have you received the fiat?",
"rationale":"need more info from both sides",
"flags":[]
}"#;
        let parsed = parse_classification(raw).unwrap();
        match parsed.suggested_action {
            SuggestedAction::AskClarification {
                buyer_text,
                seller_text,
            } => {
                assert_eq!(buyer_text, "Buyer, have you sent the fiat?");
                assert_eq!(seller_text, "Seller, have you received the fiat?");
            }
            other => panic!("expected AskClarification, got {other:?}"),
        }
    }
    // Whitespace-only text (JSON-escaped newline and tab) counts as missing.
    #[test]
    fn parse_classification_ask_clarification_rejects_blank_buyer_text() {
        let raw = r#"{
"classification":"unclear",
"confidence":0.6,
"suggested_action":"ask_clarification",
"buyer_clarification":" \n\t",
"seller_clarification":"Seller, have you received the fiat?",
"rationale":"r"
}"#;
        let err = parse_classification(raw).unwrap_err();
        match err {
            ReasoningError::MalformedResponse(msg) => {
                assert!(
                    msg.contains("buyer_clarification"),
                    "error must cite the missing field: {msg}"
                );
            }
            other => panic!("expected MalformedResponse, got {other:?}"),
        }
    }
    // An absent seller question is rejected with an error naming the field.
    #[test]
    fn parse_classification_ask_clarification_rejects_missing_seller_text() {
        let raw = r#"{
"classification":"unclear",
"confidence":0.6,
"suggested_action":"ask_clarification",
"buyer_clarification":"Buyer, did you send the fiat?",
"rationale":"r"
}"#;
        let err = parse_classification(raw).unwrap_err();
        match err {
            ReasoningError::MalformedResponse(msg) => {
                assert!(
                    msg.contains("seller_clarification"),
                    "error must cite the missing field: {msg}"
                );
            }
            other => panic!("expected MalformedResponse, got {other:?}"),
        }
    }
    // Labels outside the fixed set are errors, not silently mapped.
    #[test]
    fn parse_classification_rejects_unknown_label() {
        let raw = r#"{
"classification":"totally_made_up",
"confidence":0.5,
"suggested_action":"summarize",
"rationale":""
}"#;
        let err = parse_classification(raw).unwrap_err();
        assert!(matches!(err, ReasoningError::MalformedResponse(_)));
    }
    // Body, next step and rationale are split on their markers and trimmed.
    #[test]
    fn parse_summary_happy_path() {
        let raw = "Buyer confirmed receipt, seller confirmed funds released.\n\
                   SUGGESTED_NEXT_STEP: close the dispute in favor of buyer.\n\
                   RATIONALE: both parties aligned on the timeline.";
        let parsed = parse_summary(raw).unwrap();
        assert!(parsed.summary_text.starts_with("Buyer"));
        assert!(parsed.suggested_next_step.contains("close"));
        assert!(parsed.rationale.0.contains("aligned"));
    }
    #[test]
    fn parse_summary_rejects_empty() {
        let err = parse_summary("").unwrap_err();
        assert!(matches!(err, ReasoningError::MalformedResponse(_)));
    }
    // RATIONALE: before SUGGESTED_NEXT_STEP: is an explicit error.
    #[test]
    fn parse_summary_rejects_inverted_markers() {
        let raw = "the summary body.\n\
                   RATIONALE: some rationale.\n\
                   SUGGESTED_NEXT_STEP: too late.";
        let err = parse_summary(raw).unwrap_err();
        match err {
            ReasoningError::MalformedResponse(msg) => {
                assert!(
                    msg.to_lowercase().contains("out of order"),
                    "expected an out-of-order error: {msg}"
                );
            }
            other => panic!("expected MalformedResponse, got {other:?}"),
        }
    }
    // A missing marker yields an empty string for that part, not an error.
    #[test]
    fn parse_summary_handles_missing_rationale() {
        let raw = "just a summary.\nSUGGESTED_NEXT_STEP: do the thing.";
        let parsed = parse_summary(raw).unwrap();
        assert_eq!(parsed.summary_text, "just a summary.");
        assert_eq!(parsed.suggested_next_step, "do the thing.");
        assert_eq!(parsed.rationale.0, "");
    }
    #[test]
    fn parse_summary_handles_missing_next_step() {
        let raw = "just a summary.\nRATIONALE: because reasons.";
        let parsed = parse_summary(raw).unwrap();
        assert_eq!(parsed.summary_text, "just a summary.");
        assert_eq!(parsed.suggested_next_step, "");
        assert_eq!(parsed.rationale.0, "because reasons.");
    }
    // One unknown flag fails the whole reply, even alongside valid ones.
    #[test]
    fn parse_classification_rejects_unknown_flag() {
        let raw = r#"{
"classification":"coordination_failure_resolvable",
"confidence":0.8,
"suggested_action":"summarize",
"rationale":"",
"flags":["fraud_risk","totally_made_up"]
}"#;
        let err = parse_classification(raw).unwrap_err();
        assert!(matches!(err, ReasoningError::MalformedResponse(_)));
    }
    // "é" is two bytes: a cut at byte 2 must back off to the end of "h".
    #[test]
    fn truncate_respects_utf8_boundaries() {
        let s = "héllo";
        let got = truncate(s, 2);
        assert_eq!(got, "h");
        assert_eq!(truncate(s, 3), "hé");
        assert_eq!(truncate(s, 100), "héllo");
    }
    // Every configured retry count reaches the adapter unchanged (including 0).
    #[test]
    fn provider_honors_configured_retry_count() {
        for configured in [0u32, 1, 3, 7] {
            let cfg = ReasoningConfig {
                provider: "openai".into(),
                followup_retry_count: configured,
                ..ReasoningConfig::default()
            };
            let provider = OpenAiProvider::new(&cfg).unwrap();
            assert_eq!(
                provider.retries, configured,
                "adapter must reflect the configured followup_retry_count"
            );
        }
    }
    #[test]
    fn credential_is_read_from_api_key_field() {
        let cfg = ReasoningConfig {
            api_key: "secret-from-env".into(),
            ..ReasoningConfig::default()
        };
        let provider = OpenAiProvider::new(&cfg).unwrap();
        assert_eq!(provider.api_key, "secret-from-env");
    }
    // The endpoint URL follows api_base, with or without a trailing slash.
    #[test]
    fn request_url_uses_configured_api_base() {
        let cfg = ReasoningConfig {
            api_base: "http://localhost:8080/custom/v1".into(),
            api_key: "k".into(),
            ..ReasoningConfig::default()
        };
        let provider = OpenAiProvider::new(&cfg).unwrap();
        assert_eq!(
            provider.chat_completions_url(),
            "http://localhost:8080/custom/v1/chat/completions"
        );
        let cfg_slash = ReasoningConfig {
            api_base: "http://localhost:8080/custom/v1/".into(),
            api_key: "k".into(),
            ..ReasoningConfig::default()
        };
        let provider_slash = OpenAiProvider::new(&cfg_slash).unwrap();
        assert_eq!(
            provider_slash.chat_completions_url(),
            "http://localhost:8080/custom/v1/chat/completions",
            "trailing slash on api_base must not produce a double slash"
        );
    }
    // The configured timeout is used as-is, and 0 is floored to 1 s.
    #[test]
    fn request_timeout_is_configured() {
        let cfg = ReasoningConfig {
            request_timeout_seconds: 42,
            api_key: "k".into(),
            ..ReasoningConfig::default()
        };
        let provider = OpenAiProvider::new(&cfg).unwrap();
        assert_eq!(provider.timeout, Duration::from_secs(42));
        let cfg_zero = ReasoningConfig {
            request_timeout_seconds: 0,
            api_key: "k".into(),
            ..ReasoningConfig::default()
        };
        let provider_zero = OpenAiProvider::new(&cfg_zero).unwrap();
        assert_eq!(
            provider_zero.timeout,
            Duration::from_secs(1),
            "request_timeout_seconds = 0 must be floored to 1 s"
        );
    }
    use std::sync::Arc;
    use crate::models::dispute::InitiatorRole;
    use crate::models::reasoning::{ClassificationRequest, ReasoningContext, SummaryRequest};
    use crate::prompts::PromptBundle;
    // Each bundle section carries a unique marker, so prompt tests can check
    // that the section was embedded.
    fn fixture_bundle() -> Arc<PromptBundle> {
        Arc::new(PromptBundle {
            id: "phase3-test".to_string(),
            policy_hash: "abc123".to_string(),
            system: "SYSTEM_MARKER: you are serbero".to_string(),
            classification: "CLASSIFICATION_MARKER: policy text".to_string(),
            escalation: "ESCALATION_MARKER: escalation rules".to_string(),
            mediation_style: "STYLE_MARKER: neutral tone".to_string(),
            message_templates: "TEMPLATE_MARKER: templates here".to_string(),
        })
    }
    // The system section is sent as a separate message, so it is not checked here.
    #[test]
    fn classify_prompt_includes_every_bundle_section() {
        let req = ClassificationRequest {
            session_id: "s1".into(),
            dispute_id: "d1".into(),
            initiator_role: InitiatorRole::Buyer,
            prompt_bundle: fixture_bundle(),
            transcript: vec![],
            context: ReasoningContext {
                round_count: 0,
                last_classification: None,
                last_confidence: None,
            },
        };
        let user = build_classification_prompt(&req);
        for marker in [
            "CLASSIFICATION_MARKER",
            "ESCALATION_MARKER",
            "STYLE_MARKER",
            "TEMPLATE_MARKER",
        ] {
            assert!(
                user.contains(marker),
                "classification user prompt missing `{marker}`:\n{user}"
            );
        }
        assert!(user.contains("policy_hash: abc123"));
        assert!(user.contains("prompt_bundle_id: phase3-test"));
    }
    // The summary prompt omits the classification policy, so only three markers apply.
    #[test]
    fn summary_prompt_includes_every_relevant_bundle_section() {
        let req = SummaryRequest {
            session_id: "s1".into(),
            dispute_id: "d1".into(),
            prompt_bundle: fixture_bundle(),
            transcript: vec![],
            classification: ClassificationLabel::CoordinationFailureResolvable,
            confidence: 0.9,
        };
        let user = build_summary_prompt(&req);
        for marker in ["STYLE_MARKER", "TEMPLATE_MARKER", "ESCALATION_MARKER"] {
            assert!(
                user.contains(marker),
                "summary user prompt missing `{marker}`:\n{user}"
            );
        }
        assert!(user.contains("policy_hash: abc123"));
        assert!(user.contains("prompt_bundle_id: phase3-test"));
    }
}