use anyhow::{Context, Result};
use reqwest::{Client, StatusCode};
use secrecy::{ExposeSecret, SecretString};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::timeout;
use tracing::{debug, warn};
use url::Url;
use crate::glm46::types::*;
/// Hosts treated as trusted API endpoints; requests to them skip the
/// IP-level SSRF checks below.
const TRUSTED_HOSTS: &[&str] = &["openrouter.ai", "api.zhipu.ai", "api.openai.com"];
/// Cloud metadata endpoints that must never be reachable from this client
/// (classic SSRF exfiltration targets).
const BLOCKED_HOSTS: &[&str] = &[
    "169.254.169.254", "metadata.google.internal",
    "metadata.gcp.internal",
];
/// Validates a base URL against SSRF attacks before any request is made.
///
/// Checks, in order:
/// 1. Known cloud-metadata hosts are always rejected.
/// 2. Allow-listed hosts (exact match or true subdomain) are accepted.
/// 3. Non-local, non-private hosts must use HTTPS.
/// 4. Literal IP hosts are vetted by `validate_ip_address`.
/// 5. Localhost names are only permitted when `local_fallback` is true.
///
/// # Errors
/// Returns an error when the URL cannot be parsed or violates any rule.
pub fn validate_base_url(url: &str, local_fallback: bool) -> Result<()> {
    let parsed = Url::parse(url).context("Invalid base URL")?;
    let host = parsed
        .host_str()
        .ok_or_else(|| anyhow::anyhow!("Invalid URL: no host"))?;
    if BLOCKED_HOSTS
        .iter()
        .any(|blocked| host.eq_ignore_ascii_case(blocked))
    {
        return Err(anyhow::anyhow!(
            "SSRF protection: blocked access to cloud metadata endpoint"
        ));
    }
    // Exact host or true subdomain only. A bare suffix test
    // (`host.ends_with(trusted)`) would also accept lookalike domains such
    // as "evil-openrouter.ai", silently bypassing every check below.
    let host_lower = host.to_ascii_lowercase();
    if TRUSTED_HOSTS
        .iter()
        .any(|trusted| host_lower == *trusted || host_lower.ends_with(&format!(".{}", trusted)))
    {
        return Ok(());
    }
    let is_localhost = is_localhost_url(&parsed);
    let is_private_ip = is_private_ip_url(host);
    if !is_localhost && !is_private_ip && parsed.scheme() != "https" {
        return Err(anyhow::anyhow!(
            "SSRF protection: non-localhost URLs must use HTTPS"
        ));
    }
    // `Url::host_str` keeps the brackets on IPv6 literals ("[::1]"), which
    // `IpAddr` parsing rejects; strip them so IPv6 addresses reach the
    // IP-level checks too.
    let bare_host = host.trim_start_matches('[').trim_end_matches(']');
    if let Ok(ip) = bare_host.parse::<IpAddr>() {
        return validate_ip_address(ip, local_fallback);
    }
    if is_localhost {
        if !local_fallback {
            return Err(anyhow::anyhow!(
                "SSRF protection: localhost URLs require local_fallback=true"
            ));
        }
        return Ok(());
    }
    Ok(())
}
/// Returns true when the URL's host is a literal localhost identifier:
/// the name "localhost", IPv4 loopback, or IPv6 loopback (with or
/// without the brackets that `Url::host_str` preserves).
fn is_localhost_url(url: &Url) -> bool {
    url.host_str().map_or(false, |h| {
        h.eq_ignore_ascii_case("localhost") || matches!(h, "127.0.0.1" | "::1" | "[::1]")
    })
}
/// Returns true when `host` is a literal IP address in a private (RFC 1918)
/// or loopback range.
///
/// IPv6 literals arrive from `Url::host_str()` wrapped in brackets
/// ("[::1]"), which `IpAddr` parsing rejects, so brackets are stripped
/// before parsing — otherwise IPv6 loopback was never detected here.
fn is_private_ip_url(host: &str) -> bool {
    let bare = host.trim_start_matches('[').trim_end_matches(']');
    match bare.parse::<IpAddr>() {
        Ok(IpAddr::V4(ipv4)) => ipv4.is_private() || ipv4.is_loopback(),
        Ok(IpAddr::V6(ipv6)) => ipv6.is_loopback(),
        Err(_) => false,
    }
}
/// Enforces IP-level SSRF policy for a literal IP host.
///
/// Loopback and private/unique-local ranges require `local_fallback`;
/// link-local ranges (which include the cloud metadata service) are always
/// rejected. All other addresses are allowed.
///
/// # Errors
/// Returns a descriptive error for each rejected range.
fn validate_ip_address(ip: IpAddr, local_fallback: bool) -> Result<()> {
    match ip {
        IpAddr::V4(ipv4) => {
            if ipv4.is_loopback() {
                return if local_fallback {
                    Ok(())
                } else {
                    Err(anyhow::anyhow!(
                        "SSRF protection: loopback addresses require local_fallback=true"
                    ))
                };
            }
            if ipv4.is_private() && !local_fallback {
                return Err(anyhow::anyhow!(
                    "SSRF protection: private IP addresses require local_fallback=true"
                ));
            }
            // 169.254.0.0/16 — includes 169.254.169.254, the metadata service.
            if ipv4.is_link_local() {
                return Err(anyhow::anyhow!(
                    "SSRF protection: link-local addresses are blocked"
                ));
            }
            Ok(())
        }
        IpAddr::V6(ipv6) => {
            if ipv6.is_loopback() {
                return if local_fallback {
                    Ok(())
                } else {
                    Err(anyhow::anyhow!(
                        "SSRF protection: loopback addresses require local_fallback=true"
                    ))
                };
            }
            let first_segment = ipv6.segments()[0];
            // fe80::/10 link-local — mirror the IPv4 rule above. Computed
            // from the raw segment so we don't rely on `Ipv6Addr` helper
            // methods that may not be stable on older toolchains.
            if (first_segment & 0xffc0) == 0xfe80 {
                return Err(anyhow::anyhow!(
                    "SSRF protection: link-local addresses are blocked"
                ));
            }
            // fc00::/7 unique-local — the IPv6 analogue of RFC 1918 space;
            // treated like IPv4 private addresses.
            if (first_segment & 0xfe00) == 0xfc00 && !local_fallback {
                return Err(anyhow::anyhow!(
                    "SSRF protection: private IP addresses require local_fallback=true"
                ));
            }
            Ok(())
        }
    }
}
/// Configuration for the GLM-4.6 HTTP client.
#[derive(Clone)]
pub struct GLM46Config {
    /// API key; wrapped in `SecretString` so it is never exposed by `Debug`.
    pub api_key: SecretString,
    /// Base endpoint URL; SSRF-validated in `GLM46Client::new`.
    pub base_url: String,
    /// Model identifier sent with every request.
    pub model: String,
    /// HTTP client timeout, also reused as the outer completion timeout.
    pub timeout: Duration,
    /// Token budget used to cap `max_tokens` against the estimated input size.
    pub context_budget: usize,
    /// When true, a `CostTracker` records usage and estimated spend.
    pub cost_tracking: bool,
    /// When true, localhost/private addresses are accepted as endpoints.
    pub local_fallback: bool,
}
/// Manual `Debug` implementation so the API key is never printed; every
/// other field is rendered verbatim.
impl std::fmt::Debug for GLM46Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("GLM46Config");
        dbg.field("api_key", &"[REDACTED]");
        dbg.field("base_url", &self.base_url);
        dbg.field("model", &self.model);
        dbg.field("timeout", &self.timeout);
        dbg.field("context_budget", &self.context_budget);
        dbg.field("cost_tracking", &self.cost_tracking);
        dbg.field("local_fallback", &self.local_fallback);
        dbg.finish()
    }
}
impl Default for GLM46Config {
    /// Builds a configuration from the environment: `GLM46_API_KEY`
    /// (empty if unset — callers such as `from_env` reject that) and
    /// `GLM46_BASE_URL` (defaulting to the OpenRouter endpoint), plus
    /// conservative request defaults.
    fn default() -> Self {
        let api_key = std::env::var("GLM46_API_KEY").unwrap_or_default();
        let base_url = match std::env::var("GLM46_BASE_URL") {
            Ok(url) => url,
            Err(_) => "https://openrouter.ai/api/v1".to_string(),
        };
        Self {
            api_key: SecretString::from(api_key),
            base_url,
            model: "glm-4.6".to_string(),
            timeout: Duration::from_secs(30),
            context_budget: 198_000,
            cost_tracking: true,
            local_fallback: true,
        }
    }
}
/// Async client for the GLM-4.6 chat-completion API.
#[derive(Debug)]
pub struct GLM46Client {
    // Validated configuration (base URL already SSRF-checked in `new`).
    config: GLM46Config,
    // Shared reqwest client; cheap to clone into spawned stream tasks.
    http_client: reqwest::Client,
    // Present only when `config.cost_tracking` is enabled.
    cost_tracker: Option<Arc<Mutex<CostTracker>>>,
}
impl GLM46Client {
pub fn new(config: GLM46Config) -> Result<Self> {
validate_base_url(&config.base_url, config.local_fallback)?;
let http_client = Client::builder()
.timeout(config.timeout)
.user_agent("reasonkit-glm46/0.1.0")
.build()?;
let cost_tracker = if config.cost_tracking {
Some(Arc::new(Mutex::new(CostTracker::new())))
} else {
None
};
Ok(Self {
config,
http_client,
cost_tracker,
})
}
pub fn from_env() -> Result<Self> {
let config = GLM46Config::default();
if config.api_key.expose_secret().is_empty() {
return Err(anyhow::anyhow!(
"GLM46_API_KEY environment variable required"
));
}
Self::new(config)
}
    /// Read-only access to the client's active configuration.
    pub fn config(&self) -> &GLM46Config {
        &self.config
    }
    /// Executes a non-streaming chat completion.
    ///
    /// The request is first tightened for coordination work (low
    /// temperature, bounded `max_tokens`, structured output — see
    /// `optimize_for_coordination`), then sent with an outer timeout in
    /// addition to the HTTP client's own timeout. Usage is recorded by the
    /// cost tracker only for successful responses.
    ///
    /// # Errors
    /// Fails on timeout, transport errors, non-success API responses, or
    /// cost-tracking failures.
    pub async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse> {
        debug!(
            "Executing GLM-4.6 chat completion with {} messages",
            request.messages.len()
        );
        let optimized_request = self.optimize_for_coordination(request);
        let api_request = APIRequest::from_chat_request(&optimized_request, &self.config);
        // `??` unwraps the timeout layer first, then the transport result.
        let response = timeout(self.config.timeout, self.send_request(&api_request))
            .await
            .map_err(|_| {
                crate::error::Error::Network(format!(
                    "Request timeout after {:?}",
                    self.config.timeout
                ))
            })??;
        let chat_response = self.parse_response(response).await?;
        if let Some(tracker) = &self.cost_tracker {
            let mut tracker = tracker.lock().await;
            tracker.record_request(&chat_response, &optimized_request)?;
        }
        Ok(chat_response)
    }
    /// Starts a streaming chat completion and returns a channel of chunks.
    ///
    /// The request runs in a spawned task owning a throwaway client built
    /// from clones of this client's config and HTTP client (so the future
    /// is `'static`). Note: the temp client has no cost tracker, so
    /// streamed usage is NOT billed. Stream failures are only logged via
    /// `warn!`; the caller observes them as the receiver closing early.
    pub async fn stream_chat_completion(
        &self,
        request: ChatRequest,
    ) -> Result<tokio::sync::mpsc::UnboundedReceiver<StreamChunk>> {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let config = self.config.clone();
        let http_client = self.http_client.clone();
        tokio::spawn(async move {
            let temp_client = GLM46Client {
                config,
                http_client,
                // Streaming responses are intentionally not cost-tracked.
                cost_tracker: None,
            };
            let optimized_request = temp_client.optimize_for_coordination(request);
            let api_request =
                APIRequest::from_chat_request_stream(&optimized_request, &temp_client.config);
            match temp_client.send_stream_request(api_request, tx).await {
                Ok(_) => debug!("Stream completed successfully"),
                Err(e) => warn!("Stream error: {:?}", e),
            }
        });
        Ok(rx)
    }
pub async fn get_usage_stats(&self) -> Option<UsageStats> {
if let Some(tracker) = &self.cost_tracker {
Some(tracker.lock().await.get_stats())
} else {
None
}
}
pub async fn reset_stats(&self) {
if let Some(tracker) = &self.cost_tracker {
tracker.lock().await.reset();
}
}
    /// Probes API reachability with a minimal one-token "ping" request.
    ///
    /// Always returns `Ok`: transport failures and non-200 statuses are
    /// reported through `HealthStatus::Error` rather than `Err`. The probe
    /// uses a fixed 5-second timeout, independent of `config.timeout`.
    ///
    /// NOTE(review): posts to `base_url` as-is, same as `send_request`, so
    /// `base_url` is assumed to already be the full completions endpoint —
    /// confirm. Also omits the attribution headers `send_request` adds.
    pub async fn health_check(&self) -> Result<HealthStatus> {
        let test_request = APIRequest {
            model: self.config.model.clone(),
            messages: vec![ChatMessage::system("ping")],
            temperature: 0.1,
            max_tokens: 10,
            stream: false,
            stop: None,
            tool_choice: None,
            tools: None,
            response_format: None,
        };
        let start_time = std::time::Instant::now();
        let response = timeout(
            Duration::from_secs(5),
            self.http_client
                .post(&self.config.base_url)
                .header(
                    "Authorization",
                    format!("Bearer {}", self.config.api_key.expose_secret()),
                )
                .header("Content-Type", "application/json")
                .json(&test_request)
                .send(),
        )
        .await;
        match response {
            // Request completed; classify by HTTP status.
            Ok(Ok(resp)) => {
                let latency = start_time.elapsed();
                match resp.status() {
                    StatusCode::OK => Ok(HealthStatus::Healthy { latency }),
                    status => Ok(HealthStatus::Error {
                        status: Some(status.as_u16()),
                        message: format!("HTTP {}", status),
                    }),
                }
            }
            // Transport-level failure (DNS, TLS, connect, ...).
            Ok(Err(_)) => Ok(HealthStatus::Error {
                status: None,
                message: "HTTP request failed".to_string(),
            }),
            // The 5-second probe timeout elapsed first.
            Err(_) => Ok(HealthStatus::Error {
                status: None,
                message: "Connection timeout".to_string(),
            }),
        }
    }
    /// Tightens a request for coordination-style use before sending.
    ///
    /// - Clamps temperature to at most 0.2.
    /// - Caps `max_tokens` to half the context budget remaining after the
    ///   rough input estimate. NOTE(review): if the estimate exceeds the
    ///   budget this clamps `max_tokens` to 0 — confirm the API treats 0 as
    ///   "unset" rather than rejecting the request.
    /// - Defaults the response format to structured output when unset.
    fn optimize_for_coordination(&self, mut request: ChatRequest) -> ChatRequest {
        request.temperature = request.temperature.min(0.2);
        let input_tokens = self.estimate_tokens(&request);
        let available_context = self.config.context_budget.saturating_sub(input_tokens);
        request.max_tokens = request.max_tokens.min(available_context / 2);
        if request.response_format.is_none() {
            request.response_format = Some(ResponseFormat::Structured);
        }
        request
    }
fn estimate_tokens(&self, request: &ChatRequest) -> usize {
let content: String = request
.messages
.iter()
.map(|m| m.content.as_str())
.collect();
content.len() / 4
}
    /// Low-level POST of a JSON-serialized API request to `base_url`.
    ///
    /// Adds the bearer-auth header plus OpenRouter attribution headers
    /// (`HTTP-Referer`, `X-Title`). Transport errors are wrapped in
    /// `Error::Network`; HTTP status codes are NOT inspected here — callers
    /// (e.g. `parse_response`) handle non-success statuses.
    async fn send_request(&self, request: &APIRequest) -> Result<reqwest::Response> {
        let response = self
            .http_client
            .post(&self.config.base_url)
            .header(
                "Authorization",
                format!("Bearer {}", self.config.api_key.expose_secret()),
            )
            .header("Content-Type", "application/json")
            .header("HTTP-Referer", "https://reasonkit.sh")
            .header("X-Title", "ReasonKit GLM-4.6 Client")
            .json(request)
            .send()
            .await
            .map_err(|e| crate::error::Error::Network(e.to_string()))?;
        Ok(response)
    }
async fn send_stream_request(
&self,
request: APIRequest,
tx: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
) -> Result<()> {
let response = self.send_request(&request).await?;
let bytes_stream = response.bytes_stream();
let mut buffer = String::new();
use futures::stream::StreamExt;
let mut stream = bytes_stream;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| crate::error::Error::Network(e.to_string()))?;
let chunk_str = String::from_utf8_lossy(&chunk);
buffer.push_str(&chunk_str);
while let Some(newline_pos) = buffer.find('\n') {
let line = buffer[..newline_pos].to_string();
buffer = buffer[newline_pos + 1..].to_string();
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
return Ok(());
}
match serde_json::from_str::<StreamChunk>(data) {
Ok(chunk) => {
tx.send(chunk).map_err(|_| {
crate::error::Error::Network("Channel closed".to_string())
})?;
}
Err(e) => {
warn!("Failed to parse stream chunk: {:?}\nData: {}", e, data);
}
}
}
}
}
Ok(())
}
async fn parse_response(&self, response: reqwest::Response) -> Result<ChatResponse> {
let status = response.status();
let body = response
.text()
.await
.map_err(|e| crate::error::Error::Network(e.to_string()))?;
if !status.is_success() {
return Err(anyhow::anyhow!("API error {}: {}", status, body));
}
let api_response: APIResponse = serde_json::from_str(&body)
.with_context(|| format!("Failed to parse API response: {}", body))?;
if let Some(error) = api_response.error {
return Err(anyhow::anyhow!("GLM-4.6 API error: {}", error.message));
}
Ok(api_response.into_chat_response())
}
}
/// Accumulates request counts, token usage, and estimated cost for a session.
#[derive(Debug)]
pub struct CostTracker {
    // Running totals since construction or the last `reset`.
    stats: UsageStats,
    // Monotonic session anchor; reset alongside `stats`.
    start_time: std::time::Instant,
}
impl Default for CostTracker {
    /// Equivalent to `CostTracker::new`.
    fn default() -> Self {
        Self::new()
    }
}
impl CostTracker {
pub fn new() -> Self {
Self {
stats: UsageStats {
total_requests: 0,
total_input_tokens: 0,
total_output_tokens: 0,
total_cost: 0.0,
session_start: std::time::SystemTime::now(),
},
start_time: std::time::Instant::now(),
}
}
pub fn record_request(
&mut self,
response: &ChatResponse,
_request: &ChatRequest,
) -> Result<()> {
let input_cost_per_1k = 0.0001; let output_cost_per_1k = 0.0002;
let input_cost = (response.usage.prompt_tokens as f64 / 1000.0) * input_cost_per_1k;
let output_cost = (response.usage.completion_tokens as f64 / 1000.0) * output_cost_per_1k;
let total_cost = input_cost + output_cost;
self.stats.total_requests += 1;
self.stats.total_input_tokens += response.usage.prompt_tokens as u64;
self.stats.total_output_tokens += response.usage.completion_tokens as u64;
self.stats.total_cost += total_cost;
debug!(
"GLM-4.6 request: {} input + {} output tokens, cost: ${:.6}",
response.usage.prompt_tokens, response.usage.completion_tokens, total_cost
);
Ok(())
}
    /// Returns a clone of the current usage statistics.
    pub fn get_stats(&self) -> UsageStats {
        self.stats.clone()
    }
pub fn reset(&mut self) {
self.stats = UsageStats {
total_requests: 0,
total_input_tokens: 0,
total_output_tokens: 0,
total_cost: 0.0,
session_start: std::time::SystemTime::now(),
};
self.start_time = std::time::Instant::now();
}
}