use futures_util::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use crate::report::{AiAnalysis, Report};
use super::prompts::{analysis_system_prompt, analysis_user_message};
/// Model identifier used by `ClaudeClient::new` when the caller passes no model.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";
/// Endpoint that the request methods visible in this file POST to.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` default header of the HTTP client.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// `max_tokens` value placed in the request bodies built in this file.
const MAX_TOKENS: u32 = 4096;
/// A single conversation turn; a slice of these is serialized as the `messages`
/// field of a request body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Speaker of the turn. The constructors on this type use `"user"` and `"assistant"`.
pub role: String,
/// Text of the turn.
pub content: String,
}
impl Message {
pub fn user(content: impl Into<String>) -> Self {
Message { role: "user".into(), content: content.into() }
}
pub fn assistant(content: impl Into<String>) -> Self {
Message { role: "assistant".into(), content: content.into() }
}
}
/// JSON request body; string and message data are borrowed from the caller.
#[derive(Serialize)]
struct ApiRequest<'a> {
/// Model identifier to run the request against.
model: &'a str,
/// Value serialized as `max_tokens`.
max_tokens: u32,
/// System prompt text.
system: &'a str,
/// Conversation turns, in the order given by the caller.
messages: &'a [Message],
/// Whether a streamed response is requested. The `skip_serializing_if`
/// predicate negates the value, so the field is omitted from the JSON when
/// it is `false` and emitted only when it is `true`.
#[serde(skip_serializing_if = "std::ops::Not::not")]
stream: bool,
}
/// Subset of a non-streaming response body that this file reads; JSON fields
/// not declared here are ignored during deserialization.
#[derive(Deserialize)]
struct ApiResponse {
/// Content blocks of the reply; `analyze_report` reads only the first one.
content: Vec<ContentBlock>,
}
/// One entry of `ApiResponse::content`.
// NOTE(review): `text` is required, so a block without a `text` field makes
// deserialization of the whole response fail — confirm only text blocks are expected.
#[derive(Deserialize)]
struct ContentBlock {
/// Text carried by the block.
text: String,
}
/// The `delta` object of a streamed event.
#[derive(Deserialize)]
struct StreamDelta {
/// JSON `type` of the delta; `stream_message` only acts on `"text_delta"`.
#[serde(rename = "type")]
kind: String,
/// Text fragment, when the delta carries one.
text: Option<String>,
}
/// One streamed event, decoded from the payload of a `data: ` line.
#[derive(Deserialize)]
struct StreamEvent {
/// JSON `type` of the event; `stream_message` only acts on `"content_block_delta"`.
#[serde(rename = "type")]
kind: String,
/// Delta payload, when the event carries one.
delta: Option<StreamDelta>,
}
/// HTTP client for the messages endpoint, pre-configured with authentication headers.
pub struct ClaudeClient {
/// API key resolved by `new`. It is also installed as the `x-api-key` default
/// header there; no method visible in this part of the file reads this field.
api_key: String,
/// Model identifier sent with every request.
pub model: String,
/// `reqwest` client carrying the default headers set up in `new`.
client: reqwest::Client,
}
impl ClaudeClient {
/// Builds a client whose requests carry the API key, API version and JSON
/// content type as default headers.
///
/// `api_key` takes precedence; when it is `None` the `ANTHROPIC_API_KEY`
/// environment variable is used. `model` falls back to [`DEFAULT_MODEL`].
///
/// # Errors
///
/// Fails when no API key is available, when the key cannot be used as an
/// HTTP header value, or when the underlying HTTP client cannot be built.
pub fn new(api_key: Option<String>, model: Option<String>) -> anyhow::Result<Self> {
    let key = api_key
        .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
        .ok_or_else(|| anyhow::anyhow!(
            "No Anthropic API key found. Set ANTHROPIC_API_KEY environment variable."
        ))?;

    // The error text deliberately omits the key itself so the secret never
    // ends up in logs.
    let mut api_key_header = HeaderValue::from_str(&key).map_err(|e| {
        anyhow::anyhow!("Anthropic API key is not a valid HTTP header value: {}", e)
    })?;
    // Mark the secret so `Debug` output of the headers redacts it.
    api_key_header.set_sensitive(true);

    let mut headers = HeaderMap::new();
    headers.insert("x-api-key", api_key_header);
    headers.insert("anthropic-version", HeaderValue::from_static(ANTHROPIC_VERSION));
    headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

    let client = reqwest::Client::builder()
        .default_headers(headers)
        .build()?;

    Ok(ClaudeClient {
        api_key: key,
        model: model.unwrap_or_else(|| DEFAULT_MODEL.into()),
        client,
    })
}
/// Requests an analysis of `report` and decodes the reply into an [`AiAnalysis`].
///
/// The reply is expected to be a JSON object with the fields
/// `attack_narrative`, `critical_path` and `immediate_actions`. Only the
/// first content block of the response is read. Surrounding Markdown code
/// fences are stripped; if the remaining text still does not parse, the
/// span from the first `{` to the last `}` is tried, so a reply that wraps
/// the JSON in prose is still accepted.
///
/// # Errors
///
/// Fails when the user message cannot be built, the request fails, the API
/// answers with a non-success status, the reply is empty, or no JSON object
/// of the expected shape can be decoded from it.
pub async fn analyze_report(&self, report: &Report) -> anyhow::Result<AiAnalysis> {
    /// Shape of the JSON object the model is asked to return.
    #[derive(Deserialize)]
    struct AnalysisJson {
        attack_narrative: String,
        critical_path: Vec<String>,
        immediate_actions: Vec<String>,
    }

    let system = analysis_system_prompt(&report.domain);
    let user_msg = analysis_user_message(report)?;
    let body = ApiRequest {
        model: &self.model,
        max_tokens: MAX_TOKENS,
        system: &system,
        messages: &[Message::user(&user_msg)],
        stream: false,
    };
    let resp = self.client
        .post(ANTHROPIC_API_URL)
        .json(&body)
        .send()
        .await?;
    if !resp.status().is_success() {
        let status = resp.status();
        let text = resp.text().await.unwrap_or_default();
        anyhow::bail!("Claude API error {}: {}", status, text);
    }

    let api_resp: ApiResponse = resp.json().await?;
    let raw_text = api_resp
        .content
        .first()
        .map(|block| block.text.trim())
        .unwrap_or("");
    // Report an empty reply as such instead of as a confusing JSON EOF error.
    if raw_text.is_empty() {
        anyhow::bail!("Claude API returned an empty response; expected a JSON analysis");
    }

    let json_str = raw_text
        .trim_start_matches("```json")
        .trim_start_matches("```")
        .trim_end_matches("```")
        .trim();

    let parsed = match serde_json::from_str::<AnalysisJson>(json_str) {
        Ok(parsed) => parsed,
        Err(err) => {
            // Fallback for replies with prose around the payload. `{` and `}`
            // are single-byte, so the inclusive slice ends on a char boundary.
            let embedded = match (json_str.find('{'), json_str.rfind('}')) {
                (Some(start), Some(end)) if start < end => Some(&json_str[start..=end]),
                _ => None,
            };
            embedded
                .and_then(|candidate| serde_json::from_str::<AnalysisJson>(candidate).ok())
                .ok_or_else(|| {
                    anyhow::anyhow!(
                        "Failed to parse Claude response as JSON: {}\nRaw: {}",
                        err,
                        json_str
                    )
                })?
        }
    };

    Ok(AiAnalysis {
        model: self.model.clone(),
        attack_narrative: parsed.attack_narrative,
        critical_path: parsed.critical_path,
        immediate_actions: parsed.immediate_actions,
        generated_at: chrono::Utc::now(),
    })
}
pub async fn stream_message(
&self,
system: &str,
messages: &[Message],
) -> anyhow::Result<String> {
let body = ApiRequest {
model: &self.model,
max_tokens: MAX_TOKENS,
system,
messages,
stream: true,
};
let resp = self.client
.post(ANTHROPIC_API_URL)
.json(&body)
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
anyhow::bail!("Claude API error {}: {}", status, text);
}
let mut stream = resp.bytes_stream();
let mut full_text = String::new();
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
let bytes = chunk?;
buffer.push_str(&String::from_utf8_lossy(&bytes));
loop {
if let Some(newline_pos) = buffer.find('\n') {
let line = buffer[..newline_pos].trim().to_string();
buffer = buffer[newline_pos + 1..].to_string();
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
break;
}
if let Ok(event) = serde_json::from_str::<StreamEvent>(data) {
if event.kind == "content_block_delta" {
if let Some(delta) = event.delta {
if delta.kind == "text_delta" {
if let Some(text) = delta.text {
print!("{}", text);
use std::io::Write;
std::io::stdout().flush().ok();
full_text.push_str(&text);
}
}
}
}
}
}
} else {
break;
}
}
}
println!(); Ok(full_text)
}
}