formal-ai 0.149.0

Formal symbolic AI implementation with OpenAI-compatible APIs
Documentation
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};

use serde::Serialize;
use serde_json::json;

use crate::engine::{is_known_trace_id, knowledge_graph, knowledge_graph_dot, DEFAULT_MODEL};
use crate::protocol::{
    create_chat_completion_with_solver, create_response_with_solver, ChatCompletionRequest,
    ResponsesRequest,
};
use crate::solver::{ExecutionSurface, SolverConfig, UniversalSolver};
use crate::telegram::handle_telegram_webhook;

/// A rendered API response: status, content type, and payload.
///
/// Produced by the request router and turned into wire bytes by the HTTP server loop.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ApiHttpResponse {
    /// Numeric HTTP status code such as `200` or `404`.
    pub status_code: u16,
    /// Value for the `content-type` response header.
    pub content_type: &'static str,
    /// Response payload (may be empty).
    pub body: String,
}

/// Bearer-token authentication settings for the HTTP API.
///
/// `None` (the `Default`) disables authentication entirely.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ApiAuthConfig {
    /// Token clients must present as `Authorization: Bearer <token>`, if any.
    pub bearer_token: Option<String>,
}

/// A minimally parsed HTTP/1.1 request read off a TCP connection.
struct ParsedHttpRequest {
    /// Request method token, e.g. `GET`.
    method: String,
    /// Request target as sent, including any `?query` suffix.
    path: String,
    /// Header `(name, value)` pairs in the order they were received.
    headers: Vec<(String, String)>,
    /// Request body decoded lossily as UTF-8.
    body: String,
}

impl ApiAuthConfig {
    /// Builds a config that requires `Authorization: Bearer <token>` on protected routes.
    #[must_use]
    pub fn bearer_token(token: impl Into<String>) -> Self {
        Self {
            bearer_token: Some(token.into()),
        }
    }

    /// Reads the expected bearer token from the environment.
    ///
    /// Checks `FORMAL_AI_API_BEARER_TOKEN`, `FORMAL_AI_HTTP_BEARER_TOKEN` and
    /// `FORMAL_AI_API_TOKEN` in that order. The first one whose value is non-blank after
    /// trimming wins. When none qualifies, authentication is disabled.
    #[must_use]
    pub fn from_env() -> Self {
        Self {
            bearer_token: first_non_empty_env(&[
                "FORMAL_AI_API_BEARER_TOKEN",
                "FORMAL_AI_HTTP_BEARER_TOKEN",
                "FORMAL_AI_API_TOKEN",
            ]),
        }
    }

    /// Returns `true` when `headers` satisfy this configuration.
    ///
    /// With no configured token every request is allowed. Otherwise an `Authorization`
    /// header carrying a `Bearer` token equal to the configured one is required.
    #[must_use]
    pub fn allows(&self, headers: &[(&str, &str)]) -> bool {
        let Some(expected) = self.bearer_token.as_deref() else {
            return true;
        };
        bearer_token_from_headers(headers)
            .is_some_and(|actual| Self::tokens_match(actual.as_bytes(), expected.as_bytes()))
    }

    /// Compares two tokens without stopping at the first differing byte.
    ///
    /// A short-circuiting `==` finishes sooner the earlier the inputs differ. Over many
    /// requests that latency difference can leak the secret a prefix at a time. Only the
    /// length is allowed to influence timing here.
    fn tokens_match(actual: &[u8], expected: &[u8]) -> bool {
        if actual.len() != expected.len() {
            return false;
        }
        actual
            .iter()
            .zip(expected)
            .fold(0_u8, |diff, (a, b)| diff | (a ^ b))
            == 0
    }
}

/// Routes a request that carries no headers, using auth settings read from the environment.
///
/// No `Authorization` header is supplied. So when a bearer token is configured via the
/// environment, every `/v1/` route answers `401`.
#[must_use]
pub fn handle_api_request(method: &str, path: &str, body: &str) -> ApiHttpResponse {
    let auth = ApiAuthConfig::from_env();
    let no_headers: &[(&str, &str)] = &[];
    handle_api_request_with_auth(method, path, no_headers, body, &auth)
}

/// Routes a request with the given headers, using auth settings read from the environment.
///
/// The environment is re-read on every call, so a token change takes effect on the next
/// request.
#[must_use]
pub fn handle_api_request_with_headers(
    method: &str,
    path: &str,
    headers: &[(&str, &str)],
    body: &str,
) -> ApiHttpResponse {
    let auth = ApiAuthConfig::from_env();
    handle_api_request_with_auth(method, path, headers, body, &auth)
}

/// Routes one API request and renders its response.
///
/// `path` may include a query string (`/v1/graph?format=dot`). Only the part before the
/// first `?` selects the route; the remainder is handed to the graph handler. Routes under
/// `/v1/` (other than `OPTIONS` preflights) require the bearer token configured in `auth`;
/// a failed check answers `401`.
///
/// Routes:
/// - `OPTIONS <any>` → `204`, empty body.
/// - `GET /health`, `GET /v1/models`, `GET /v1/graph`.
/// - `POST /v1/chat/completions` (SSE when the request sets `stream`), `POST /v1/responses`.
/// - `POST /telegram/webhook`.
///
/// Anything else answers `404`. Unparseable request bodies answer `400`.
#[must_use]
pub fn handle_api_request_with_auth(
    method: &str,
    path: &str,
    headers: &[(&str, &str)],
    body: &str,
    auth: &ApiAuthConfig,
) -> ApiHttpResponse {
    let (route, query) = path.split_once('?').unwrap_or((path, ""));

    if requires_bearer_auth(method, route) && !auth.allows(headers) {
        return error_response(401, "missing or invalid bearer token");
    }

    // Responses that deliberately carry no payload.
    let empty = |status_code| ApiHttpResponse {
        status_code,
        content_type: "application/json",
        body: String::new(),
    };

    match (method, route) {
        ("OPTIONS", _) => empty(204),
        ("GET", "/health") => json_response(
            200,
            &json!({
                "status": "ok",
                "model": DEFAULT_MODEL,
            }),
        ),
        ("GET", "/v1/models") => json_response(
            200,
            &json!({
                "object": "list",
                "data": [{
                    "id": DEFAULT_MODEL,
                    "object": "model",
                    "created": 0,
                    "owned_by": "link-assistant"
                }],
                "rate_limit": {
                    "requests_per_minute": 60,
                    "tokens_per_minute": 60_000
                }
            }),
        ),
        ("GET", "/v1/graph") => handle_graph_request(query),
        ("POST", "/v1/chat/completions") => {
            match serde_json::from_str::<ChatCompletionRequest>(body) {
                Err(error) => error_response(400, &format!("invalid chat request: {error}")),
                Ok(request) => {
                    let solver = http_solver();
                    let completion = create_chat_completion_with_solver(&request, &solver);
                    if request.stream {
                        sse_response(&completion)
                    } else {
                        json_response(200, &completion)
                    }
                }
            }
        }
        ("POST", "/v1/responses") => match serde_json::from_str::<ResponsesRequest>(body) {
            Err(error) => error_response(400, &format!("invalid responses request: {error}")),
            Ok(request) => {
                let solver = http_solver();
                json_response(200, &create_response_with_solver(&request, &solver))
            }
        },
        ("POST", "/telegram/webhook") => match handle_telegram_webhook(body) {
            Err(error) => error_response(400, &error.to_string()),
            Ok(None) => empty(200),
            Ok(Some(reply)) => json_response(200, &reply),
        },
        _ => error_response(404, "route not found"),
    }
}

/// Decides whether a request must present the bearer token.
///
/// CORS preflights (`OPTIONS`) are always exempt. Otherwise only paths under `/v1/` are
/// protected.
fn requires_bearer_auth(method: &str, normalized_path: &str) -> bool {
    if method == "OPTIONS" {
        return false;
    }
    normalized_path.starts_with("/v1/")
}

/// Returns the trimmed value of the first variable in `names` that is set, valid Unicode
/// and non-blank.
///
/// Returns `None` when no variable qualifies.
fn first_non_empty_env(names: &[&str]) -> Option<String> {
    for name in names {
        if let Ok(raw) = std::env::var(name) {
            let candidate = raw.trim();
            if !candidate.is_empty() {
                return Some(candidate.to_owned());
            }
        }
    }
    None
}

/// Extracts the bearer token from the first `Authorization` header that parses as one.
///
/// The header name is matched case-insensitively. `Authorization` headers using another
/// scheme are skipped rather than rejected.
fn bearer_token_from_headers<'a>(headers: &'a [(&str, &str)]) -> Option<&'a str> {
    headers
        .iter()
        .filter(|(name, _)| name.eq_ignore_ascii_case("authorization"))
        .find_map(|(_, value)| parse_bearer_token(value))
}

/// Parses an `Authorization` header value of the form `Bearer <token>`.
///
/// The scheme is case-insensitive. Exactly two whitespace-separated words are required;
/// anything else yields `None`.
fn parse_bearer_token(value: &str) -> Option<&str> {
    let mut words = value.split_whitespace();
    match (words.next(), words.next(), words.next()) {
        (Some(scheme), Some(token), None) if scheme.eq_ignore_ascii_case("bearer") => Some(token),
        _ => None,
    }
}

/// Builds a solver from environment-derived settings.
///
/// The execution surface is overridden to `ExecutionSurface::HttpServer`.
fn http_solver() -> UniversalSolver {
    let http_config = {
        let mut base = SolverConfig::from_env();
        base.execution_surface = ExecutionSurface::HttpServer;
        base
    };
    UniversalSolver::new(http_config)
}

/// Serves the knowledge graph for `GET /v1/graph`.
///
/// Recognised query keys:
/// - `trace`: must name a known trace id, otherwise `404`. It is only validated; the
///   graph returned is the same either way.
/// - `format=dot`: returns the DOT rendering as `text/plain` instead of JSON.
///
/// For repeated keys the last occurrence wins. Pairs without `=` and unknown keys are
/// ignored. Values are used as-is (no percent-decoding).
fn handle_graph_request(query: &str) -> ApiHttpResponse {
    let (trace, format) = query
        .split('&')
        .filter(|part| !part.is_empty())
        .filter_map(|pair| pair.split_once('='))
        .fold((None, None), |(trace, format), (key, value)| match key {
            "trace" => (Some(value), format),
            "format" => (trace, Some(value)),
            _ => (trace, format),
        });

    if trace.is_some_and(|trace_id| !is_known_trace_id(trace_id)) {
        return error_response(404, "unknown trace id");
    }

    match format {
        Some("dot") => ApiHttpResponse {
            status_code: 200,
            content_type: "text/plain",
            body: knowledge_graph_dot(),
        },
        _ => json_response(200, &knowledge_graph()),
    }
}

/// Wraps `value` as a single Server-Sent Events `data:` frame followed by `data: [DONE]`.
///
/// If serialization fails, the first frame's payload is empty.
fn sse_response<T: Serialize>(value: &T) -> ApiHttpResponse {
    let mut body = String::from("data: ");
    body.push_str(&serde_json::to_string(value).unwrap_or_default());
    body.push_str("\n\ndata: [DONE]\n\n");
    ApiHttpResponse {
        status_code: 200,
        content_type: "text/event-stream",
        body,
    }
}

/// Binds `address` and serves HTTP requests, one connection at a time, indefinitely.
///
/// Each accepted connection gets read and write timeouts. The loop is sequential, so
/// without them a client that stops sending (or stops reading the response) would block
/// every other client forever. Per-connection failures, including timeouts, are logged to
/// stderr and the loop moves on.
///
/// # Errors
/// Returns an error only if binding the listener fails. `TcpListener::incoming` never
/// ends, so the trailing `Ok(())` is not reached in practice.
pub fn serve(address: &str) -> std::io::Result<()> {
    const IO_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    let listener = TcpListener::bind(address)?;
    eprintln!("formal-ai server listening on http://{address}");

    for stream in listener.incoming() {
        let mut stream = match stream {
            Ok(stream) => stream,
            Err(error) => {
                eprintln!("connection failed: {error}");
                continue;
            }
        };

        // Best effort: if timeouts cannot be set, still serve the client rather than drop it.
        if let Err(error) = stream
            .set_read_timeout(Some(IO_TIMEOUT))
            .and_then(|()| stream.set_write_timeout(Some(IO_TIMEOUT)))
        {
            eprintln!("failed to configure connection timeouts: {error}");
        }

        if let Err(error) = handle_connection(&mut stream) {
            eprintln!("request failed: {error}");
        }
    }

    Ok(())
}

/// Reads one request from `stream`, routes it, and writes the response back.
///
/// A connection closed before a full request arrived is treated as a no-op.
fn handle_connection(stream: &mut TcpStream) -> std::io::Result<()> {
    let request = match read_request(stream)? {
        Some(request) => request,
        None => return Ok(()),
    };

    let mut header_refs = Vec::with_capacity(request.headers.len());
    for (name, value) in &request.headers {
        header_refs.push((name.as_str(), value.as_str()));
    }

    let response = handle_api_request_with_headers(
        &request.method,
        &request.path,
        &header_refs,
        &request.body,
    );
    write_response(stream, &response)
}

/// Reads one HTTP/1.1 request (request line, headers and body) from `stream`.
///
/// - Returns `Ok(None)` when the peer closes the connection before the blank line that
///   ends the header section arrives.
/// - The body length comes from `Content-Length`; without that header the body is empty
///   (chunked transfer encoding is not supported).
/// - If the peer closes before the declared body is complete, the body is whatever
///   arrived.
///
/// # Errors
/// Propagates socket read errors. Returns `InvalidData` when the header section grows
/// past `MAX_HEADER_BYTES`, or when `Content-Length` declares more than `MAX_BODY_BYTES`.
/// Without these caps a client could make the server buffer an unbounded amount of data.
fn read_request(stream: &mut TcpStream) -> std::io::Result<Option<ParsedHttpRequest>> {
    const MAX_HEADER_BYTES: usize = 64 * 1024;
    const MAX_BODY_BYTES: usize = 32 * 1024 * 1024;

    let mut buffer = [0_u8; 8192];
    let bytes_read = stream.read(&mut buffer)?;
    if bytes_read == 0 {
        return Ok(None);
    }

    let mut request_bytes = buffer[..bytes_read].to_vec();
    let header_end = loop {
        if let Some(position) = find_header_end(&request_bytes) {
            break position;
        }
        if request_bytes.len() > MAX_HEADER_BYTES {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "request header section too large",
            ));
        }
        let bytes_read = stream.read(&mut buffer)?;
        if bytes_read == 0 {
            return Ok(None);
        }
        request_bytes.extend_from_slice(&buffer[..bytes_read]);
    };

    let header_text = String::from_utf8_lossy(&request_bytes[..header_end]).to_string();
    let content_length = content_length(&header_text);
    if content_length > MAX_BODY_BYTES {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "request body too large",
        ));
    }
    // Skip the "\r\n\r\n" that terminates the header section.
    let body_start = header_end + 4;
    let body_target = body_start.saturating_add(content_length);

    while request_bytes.len() < body_target {
        let bytes_read = stream.read(&mut buffer)?;
        if bytes_read == 0 {
            break;
        }
        request_bytes.extend_from_slice(&buffer[..bytes_read]);
    }

    let request_line = header_text.lines().next().unwrap_or_default();
    let mut request_parts = request_line.split_whitespace();
    let method = request_parts.next().unwrap_or_default().to_owned();
    let path = request_parts.next().unwrap_or_default().to_owned();
    let headers = request_headers(&header_text);
    // Bytes past the declared length (e.g. a pipelined request) are ignored.
    let body_end = body_target.min(request_bytes.len());
    let body = String::from_utf8_lossy(&request_bytes[body_start..body_end]).to_string();

    Ok(Some(ParsedHttpRequest {
        method,
        path,
        headers,
        body,
    }))
}

/// Serializes `response` as an HTTP/1.1 message and writes it to `stream` in one call.
///
/// Every response gets permissive CORS headers and `connection: close`. Per RFC 9110 §8.6,
/// 1xx and 204 responses carry neither a `content-length` header nor a body. For all other
/// statuses the body is sent as-is, preceded by its byte length.
fn write_response(stream: &mut TcpStream, response: &ApiHttpResponse) -> std::io::Result<()> {
    let (status_code, reason) = http_status_line(response.status_code);
    let bodyless = (100..200).contains(&status_code) || status_code == 204;

    let mut head = format!(
        "HTTP/1.1 {status_code} {reason}\r\ncontent-type: {}\r\n",
        response.content_type
    );
    if !bodyless {
        head.push_str(&format!("content-length: {}\r\n", response.body.len()));
    }
    head.push_str(
        "access-control-allow-origin: *\r\n\
         access-control-allow-methods: GET,POST,OPTIONS\r\n\
         access-control-allow-headers: content-type,authorization\r\n\
         connection: close\r\n\
         \r\n",
    );

    let body: &[u8] = if bodyless {
        &[]
    } else {
        response.body.as_bytes()
    };
    // Assemble the whole message first: `write!` straight onto the socket issues one
    // syscall (and potentially one TCP segment) per formatted fragment.
    let mut message = Vec::with_capacity(head.len() + body.len());
    message.extend_from_slice(head.as_bytes());
    message.extend_from_slice(body);
    stream.write_all(&message)?;
    stream.flush()
}

/// Maps a status code to the `(code, reason phrase)` pair for the HTTP status line.
///
/// Common codes get their standard phrase. Other codes in `100..=599` keep their number
/// with a generic class phrase. Values outside that range are not valid HTTP statuses and
/// are reported as `500 Internal Server Error`.
fn http_status_line(status_code: u16) -> (u16, &'static str) {
    let reason = match status_code {
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        408 => "Request Timeout",
        413 => "Content Too Large",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Content",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        100..=199 => "Informational",
        201..=299 => "Success",
        300..=399 => "Redirection",
        400..=499 => "Client Error",
        500..=599 => "Server Error",
        _ => return (500, "Internal Server Error"),
    };
    (status_code, reason)
}

fn json_response<T: Serialize>(status_code: u16, value: &T) -> ApiHttpResponse {
    match serde_json::to_string_pretty(value) {
        Ok(body) => ApiHttpResponse {
            status_code,
            content_type: "application/json",
            body,
        },
        Err(error) => error_response(500, &format!("failed to serialize response: {error}")),
    }
}

/// Builds a compact JSON error body of the form
/// `{"error": {"message": ..., "type": "formal_ai_error"}}` with the given status.
fn error_response(status_code: u16, message: &str) -> ApiHttpResponse {
    let payload = json!({
        "error": {
            "message": message,
            "type": "formal_ai_error"
        }
    });
    ApiHttpResponse {
        status_code,
        content_type: "application/json",
        body: payload.to_string(),
    }
}

/// Returns the offset of the first `\r\n\r\n`, which terminates the HTTP header section.
fn find_header_end(bytes: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    let last_start = bytes.len().checked_sub(TERMINATOR.len())?;
    (0..=last_start).find(|&start| bytes[start..].starts_with(TERMINATOR))
}

/// Parses header lines (everything after the request line) into trimmed name/value pairs.
///
/// Each line is split at its first `:`. Lines without a colon are skipped.
fn request_headers(headers: &str) -> Vec<(String, String)> {
    let mut parsed = Vec::new();
    for line in headers.lines().skip(1) {
        if let Some((name, value)) = line.split_once(':') {
            parsed.push((name.trim().to_owned(), value.trim().to_owned()));
        }
    }
    parsed
}

/// Returns the value of the first `Content-Length` header that parses as `usize`.
///
/// The header name is matched case-insensitively. Unparseable values are skipped, and
/// `0` is returned when no usable header exists.
fn content_length(headers: &str) -> usize {
    for line in headers.lines() {
        let Some((name, value)) = line.split_once(':') else {
            continue;
        };
        if !name.eq_ignore_ascii_case("content-length") {
            continue;
        }
        if let Ok(length) = value.trim().parse::<usize>() {
            return length;
        }
    }
    0
}