use axum::{
extract::State,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use clap::Parser;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};
use tracing::{info, warn};
use uuid::Uuid;
use crate::api::{Claude, Session};
// Command-line options for the server, parsed through clap's derive API.
// NOTE: plain `//` comments are used on purpose here; `///` doc comments on a
// clap-derived struct are picked up as help text and would change `--help`.
#[derive(Parser, Debug, Clone)]
#[clap(author, version, about = "OpenAI-compatible Claude API server")]
pub struct Args {
// TCP port the server listens on (long and short flag, default "3000").
#[clap(long, short, default_value = "3000")]
pub port: u16,
// When set, `run_with_args` installs the default tracing subscriber;
// otherwise log output is capped at WARN.
#[clap(long)]
pub debug: bool,
// Model identifier handed to `Claude::new` by `run_with_args`.
#[clap(long, default_value = "claude-3-7-sonnet-20250219")]
pub model: String,
// Address to bind the listener to (default "0.0.0.0").
#[clap(long, default_value = "0.0.0.0")]
pub host: String,
}
/// State shared with the request handlers via axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
/// Client used by the handler to create chats and send messages.
pub claude: Arc<Claude>,
/// Tracked conversations: the key produced by `conversation_key` maps to
/// `(conversation id, message history)`.
pub conversations: Arc<DashMap<String, (String, Vec<Message>)>>,
}
/// JSON body returned for a successful chat completion request.
#[derive(Serialize, Debug)]
pub struct OpenAIResponse {
/// Identifier of this response; the handler fills it with a fresh UUID.
id: String,
/// Object type label; the handler sets it to "chat.completion".
object: &'static str,
/// Creation time as a Unix timestamp in seconds.
created: u64,
/// Model name reported to the client (the result of `get_model`).
model: String,
/// Fingerprint string; the handler uses a fixed value.
system_fingerprint: String,
/// Generated choices; the handler always returns exactly one.
choices: Vec<Choice>,
/// Token accounting for the request.
usage: Usage,
}
/// Token counts reported alongside a completion.
///
/// The handler in this file does not tokenize; it estimates these values
/// from character counts divided by four.
#[derive(Serialize, Debug)]
pub struct Usage {
/// Estimated tokens in the prompt.
prompt_tokens: u32,
/// Estimated tokens in the generated completion.
completion_tokens: u32,
/// Estimated tokens for prompt and completion together.
total_tokens: u32,
}
/// A single chat message as sent by clients and echoed back in responses.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Message {
/// Author of the message, e.g. "user", "assistant", "system" or "developer".
role: String,
/// Message payload in any of the shapes accepted by `ContentValue`.
content: ContentValue,
}
/// The shapes a message's `content` field may take.
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// and keeps the first one that matches; reordering variants changes which
/// shape an ambiguous payload deserializes into.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum ContentValue {
/// Plain text.
String(String),
/// A list of text fragments.
Array(Vec<String>),
/// A single JSON object, optionally carrying `text` and `type`.
Object(ContentObject),
/// A list of typed content blocks.
ContentBlocks(Vec<ContentBlock>),
}
/// Object-shaped message content where both `text` and `type` are optional.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct ContentObject {
/// Text payload, if the object has one.
#[serde(default)]
pub text: Option<String>,
/// Value of the object's `type` key, if present.
#[serde(default)]
pub r#type: Option<String>,
/// Every other key of the object, kept as raw JSON values.
#[serde(flatten)]
pub other: std::collections::HashMap<String, serde_json::Value>,
}
/// One element of block-style message content; `type` is required.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct ContentBlock {
/// Value of the block's `type` key.
pub r#type: String,
/// Text payload, if the block has one.
#[serde(default)]
pub text: Option<String>,
/// Every other key of the block, kept as raw JSON values.
#[serde(flatten)]
pub other: std::collections::HashMap<String, serde_json::Value>,
}
impl Message {
    /// Flattens the message content into a single string.
    ///
    /// * plain strings are returned as-is;
    /// * string arrays are joined with a single space;
    /// * objects yield their `text` field, or their JSON serialization when
    ///   `text` is absent (an empty string if serialization fails);
    /// * content blocks contribute the `text` of every block that has one,
    ///   joined with a single space.
    pub fn content_as_string(&self) -> String {
        match &self.content {
            ContentValue::String(text) => text.to_owned(),
            ContentValue::Array(parts) => parts.join(" "),
            ContentValue::Object(object) => match object.text.as_ref() {
                Some(text) => text.to_owned(),
                None => serde_json::to_string(object).unwrap_or_default(),
            },
            ContentValue::ContentBlocks(blocks) => {
                let texts: Vec<&str> = blocks
                    .iter()
                    .filter_map(|block| block.text.as_deref())
                    .collect();
                texts.join(" ")
            }
        }
    }
}
/// A message whose content is always a plain string; converts into
/// `Message` through the `From` implementation.
#[derive(Debug, Deserialize)]
pub struct MessageInner {
/// Author of the message.
role: String,
/// Text of the message.
content: String,
}
impl From<MessageInner> for Message {
fn from(inner: MessageInner) -> Self {
Self {
role: inner.role,
content: ContentValue::String(inner.content),
}
}
}
/// Serde helpers for a field that may arrive either as a single string or as
/// an array of strings but is stored as one `String`.
mod content_string_or_array {
    use serde::{self, Deserialize, Deserializer, Serializer};

    /// Serializes the content as a plain JSON string.
    pub fn _serialize<S>(content: &str, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(content)
    }

    /// Deserializes either a string or an array of strings into one `String`.
    ///
    /// A string is returned unchanged; array elements are joined with a single
    /// space, matching how `Message::content_as_string` flattens arrays.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error when the input is neither a string
    /// nor an array of strings.
    pub fn _deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
    where
        D: Deserializer<'de>,
    {
        // The variants must carry the payload: with unit payloads neither a
        // string nor an array can match, and the content would be lost.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum StringOrArray {
            String(String),
            Array(Vec<String>),
        }

        match StringOrArray::deserialize(deserializer)? {
            StringOrArray::String(text) => Ok(text),
            StringOrArray::Array(parts) => Ok(parts.join(" ")),
        }
    }
}
/// Body of an incoming chat completion request.
#[derive(Deserialize, Serialize, Debug)]
pub struct ChatCompletionRequest {
/// Model name requested by the client; translated by `get_model`.
model: String,
/// Conversation so far; the handler treats the last entry as the new query.
messages: Vec<Message>,
}
/// One completion alternative inside an `OpenAIResponse`.
#[derive(Serialize, Debug)]
pub struct Choice {
/// Position of this choice in the `choices` list.
index: i32,
/// The generated assistant message.
message: Message,
/// Log-probability data; the handler in this file always sets `None`.
logprobs: Option<serde_json::Value>,
/// Why generation ended; the handler in this file always reports "stop".
finish_reason: String,
}
/// Translates a client-supplied model name into one of the configured model
/// identifiers.
///
/// `"gpt-3.5-turbo"` selects the Haiku model, `"gpt-4o"` and `"gpt-4.1"`
/// select the Opus model, and everything else selects the Sonnet model.
fn get_model(model_name: &str) -> &'static str {
    if model_name == "gpt-3.5-turbo" {
        return crate::config::HAIKU_MODEL;
    }
    if matches!(model_name, "gpt-4o" | "gpt-4.1") {
        return crate::config::OPUS_MODEL;
    }
    // "gpt-4", "gpt-4-turbo" and every unrecognised name share this default.
    crate::config::SONNET_MODEL
}
/// Maps a request role onto the label used when building a text prompt.
///
/// `"assistant"` stays `"assistant"`, `"system"` and `"developer"` become
/// `"system"`, and every other role (including `"user"`) becomes `"human"`.
fn _normalize_role(role: &str) -> String {
    let label = match role {
        "assistant" => "assistant",
        "system" | "developer" => "system",
        // "user" and any unknown role are treated as the human speaker.
        _ => "human",
    };
    label.to_owned()
}
/// Renders a message list as one prompt string.
///
/// Each message becomes `"<Speaker>: <content>"` followed by a blank line,
/// where the speaker label comes from the normalized role.
fn _messages_to_prompt(messages: &[Message]) -> String {
    messages
        .iter()
        .map(|message| {
            let speaker = _normalize_role(&message.role);
            let body = message.content_as_string();
            let label = match speaker.as_str() {
                "system" => "System",
                "human" => "Human",
                "assistant" => "Assistant",
                other => other,
            };
            format!("{}: {}\n\n", label, body)
        })
        .collect()
}
/// Derives the lookup key for a conversation: the flattened content of the
/// first message whose role is exactly `"user"`, or an empty string when no
/// such message exists.
fn conversation_key(messages: &[Message]) -> String {
    for message in messages {
        if message.role == "user" {
            return message.content_as_string();
        }
    }
    String::new()
}
async fn chat_completion_handler(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> impl IntoResponse {
for (i, msg) in request.messages.iter().enumerate() {
match &msg.content {
ContentValue::Array(_) => info!("Message at index {} has array content", i),
ContentValue::Object(_) => info!("Message at index {} has object content", i),
ContentValue::ContentBlocks(_) => info!("Message at index {} has content blocks", i),
_ => {} }
}
let model_id = get_model(&request.model);
if request.messages.is_empty() {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "No messages provided in the request"
})),
);
}
let messages = request.messages;
let conv_key = conversation_key(&messages);
let user_query = messages.last().unwrap().clone();
let claude = state.claude;
let conversations = state.conversations;
let mut conversation_id: Option<String> = None;
let mut message_history: Vec<Message> = Vec::new();
info!("Looking for conversation key: {}", conv_key);
info!("Available conversation keys:");
for item in conversations.iter() {
info!(" Key: '{}', {} chars", item.key(), item.key().len());
}
if let Some(existing) = conversations.get(&conv_key) {
conversation_id = Some(existing.0.clone());
message_history = existing.1.clone();
info!("Found conversation for key: '{}'", conv_key);
} else {
info!("No existing conversation found for key: '{}'", conv_key);
}
if conversation_id.is_none() {
match claude.create_chat().await {
Ok(new_id) => {
conversation_id = Some(new_id);
message_history = messages[..messages.len() - 1].to_vec();
conversations.insert(
conv_key.clone(),
(conversation_id.clone().unwrap(), message_history.clone()),
);
if let Some(system_msg) = messages
.iter()
.find(|m| m.role == "system" || m.role == "developer")
{
if let Err(e) = claude
.send_message(
&conversation_id.clone().unwrap(),
&system_msg.content_as_string(),
&[],
)
.await
{
warn!("Failed to send system prompt: {}", e);
}
}
}
Err(e) => {
let error_response = serde_json::json!({
"error": format!("Failed to create conversation: {}", e)
});
dbg!(&error_response);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
}
}
}
let completion = match claude
.send_message(
&conversation_id.clone().unwrap(),
&user_query.content_as_string(),
&[],
)
.await
{
Ok(response) => response,
Err(e) => {
let error_response = serde_json::json!({
"error": format!("Failed to get completion: {}", e)
});
dbg!(&error_response);
return (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response));
}
};
let new_message = Message {
role: "assistant".to_string(),
content: ContentValue::String(completion.clone()),
};
let mut updated_history = message_history.clone();
updated_history.push(user_query.clone());
updated_history.push(new_message);
if let Some(_) = conversations.remove(&conv_key) {
info!("Removed old conversation entry with key: '{}'", conv_key);
}
let new_conv_key = conversation_key(&updated_history);
let conv_id = conversation_id.unwrap();
conversations.insert(
new_conv_key.clone(),
(conv_id.clone(), updated_history.clone()),
);
info!(
"Stored conversation with ID: {} under new key '{}', {} chars",
conv_id,
new_conv_key,
new_conv_key.len()
);
let response = OpenAIResponse {
id: Uuid::new_v4().to_string(),
object: "chat.completion",
created: chrono::Utc::now().timestamp() as u64,
model: model_id.to_string(),
system_fingerprint: "fp_toast".to_string(),
choices: vec![Choice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: ContentValue::String(completion.clone()),
},
logprobs: None,
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: user_query.content_as_string().chars().count() as u32 / 4, completion_tokens: completion.chars().count() as u32 / 4, total_tokens: (user_query.content_as_string().chars().count()
+ completion.chars().count()) as u32
/ 4,
},
};
(
StatusCode::OK,
Json(serde_json::to_value(response).unwrap()),
)
}
/// Liveness probe: always answers `200 OK` with the body `OK`.
async fn health_check() -> impl IntoResponse {
    let status = StatusCode::OK;
    let body = "OK";
    (status, body)
}
/// Parses the process's command-line arguments and runs the server with them.
pub async fn run() -> anyhow::Result<()> {
    let cli_args = Args::parse();
    run_with_args(cli_args).await
}
pub async fn run_with_args(args: Args) -> anyhow::Result<()> {
if args.debug {
tracing_subscriber::fmt::init();
} else {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::WARN)
.init();
}
let config_dir = dirs::config_dir()
.ok_or_else(|| anyhow::anyhow!("Could not determine config directory"))?
.join("toast");
let cookie_path = config_dir.join("cookie");
let org_id_path = config_dir.join("org_id");
if !config_dir.exists() {
return Err(anyhow::anyhow!(
"Configuration directory does not exist at {:?}",
config_dir
));
}
let cookie = if cookie_path.exists() {
std::fs::read_to_string(&cookie_path)?.trim().to_string()
} else {
return Err(anyhow::anyhow!(
"Cookie file not found at {:?}",
cookie_path
));
};
let org_id = if org_id_path.exists() {
std::fs::read_to_string(&org_id_path)?.trim().to_string()
} else {
if let Some(extracted_org_id) = crate::cli::extract_org_id_from_cookie(&cookie) {
std::fs::write(&org_id_path, &extracted_org_id)?;
info!(
"Extracted organization ID from cookie and saved to {:?}",
org_id_path
);
extracted_org_id
} else {
return Err(anyhow::anyhow!(
"Organization ID file not found at {:?} and couldn't extract it from cookie",
org_id_path
));
}
};
let user_agent =
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:137.0) Gecko/20100101 Firefox/137.0"
.to_string();
let session = Session {
cookie,
user_agent,
organization_id: org_id,
};
let model_str = Box::leak(args.model.into_boxed_str());
let claude = Arc::new(Claude::new(session.clone(), model_str)?);
let app_state = AppState {
claude,
conversations: Arc::new(DashMap::new()),
};
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
let app = Router::new()
.route("/health", get(health_check))
.route("/v1/chat/completions", post(chat_completion_handler))
.route("/", post(chat_completion_handler))
.route("/chat/completions", post(chat_completion_handler))
.layer(cors)
.with_state(app_state);
let host_parts: Vec<u8> = args
.host
.split('.')
.map(|s| s.parse::<u8>().unwrap_or(0))
.collect();
let host_addr = if host_parts.len() == 4 {
[host_parts[0], host_parts[1], host_parts[2], host_parts[3]]
} else {
[0, 0, 0, 0]
};
let addr = SocketAddr::from((host_addr, args.port));
info!("Server listening on {}", addr);
info!("Example request:");
info!(
"curl -X POST http://localhost:{}{} \\",
args.port, "/v1/chat/completions # Also available at / and /chat/completions"
);
info!(" -H \"Content-Type: application/json\" \\");
info!(" -d '{{");
info!(" \"model\": \"gpt-4.1\",");
info!(" \"messages\": [");
info!(" {{");
info!(" \"role\": \"developer\",");
info!(" \"content\": \"You are a helpful assistant.\"");
info!(" }},");
info!(" {{");
info!(" \"role\": \"user\",");
info!(" \"content\": \"Hello! (Simple string content)\"");
info!(" }},");
info!(" {{");
info!(" \"role\": \"user\",");
info!(" \"content\": [\"Array of\", \"string content\"]");
info!(" }},");
info!(" {{");
info!(" \"role\": \"user\",");
info!(" \"content\": {{\"type\": \"text\", \"text\": \"Object with text field\"}}");
info!(" }},");
info!(" {{");
info!(" \"role\": \"user\",");
info!(" \"content\": [");
info!(" {{\"type\": \"text\", \"text\": \"Array of content blocks\"}},");
info!(" {{\"type\": \"text\", \"text\": \"with multiple entries\"}}");
info!(" ]");
info!(" }}");
info!(" ]");
info!(" }}'");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}