atomfn 0.1.2

AtomService 函数服务 Rust SDK:与 TS SDK 协议一致的常驻 HTTP 运行时
Documentation
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use std::thread;

use serde::Deserialize;
use serde_json::{json, Value};

use crate::context::{generate_request_id, Ctx};
use crate::registry::{FunctionConfig, Outcome, Registry};

/// Port the runtime listens on when `ATOMFN_PORT` is unset or not a valid `u16`.
const DEFAULT_PORT: u16 = 8_080;

/// Immutable state shared (through an `Arc`) by every connection-handling thread.
struct ServerState {
    /// Project name read from `ATOMFN_PROJECT`; copied into each request's `Ctx`.
    project: String,
    /// Bundle name read from `ATOMFN_BUNDLE`; copied into each request's `Ctx`.
    bundle: String,
    /// Registered one-shot and streaming function handlers.
    registry: Registry,
}

/// The parts of an incoming HTTP request that the runtime actually uses.
struct ParsedRequest {
    /// Request method exactly as sent (e.g. `GET`, `POST`); not case-normalized.
    method: String,
    /// Request target up to, but not including, the first `?`.
    path: String,
    /// Raw (still percent-encoded) text after the first `?`, if there was one.
    query: Option<String>,
    /// Request body: exactly `Content-Length` bytes, or empty when no length was sent.
    body: Vec<u8>,
    /// Value of the `x-atomfn-request-id` header, if the client supplied one.
    request_id: Option<String>,
}

/// Top-level shape of the optional functions config file.
///
/// A document without a `functions` key deserializes to an empty map.
#[derive(Deserialize, Default)]
#[serde(default)]
struct ConfigFile {
    /// Per-function configuration; presumably keyed by function name — verify against
    /// `Registry::load_configs`.
    functions: HashMap<String, FunctionConfig>,
}

/// Loads per-function configuration from the JSON file named by `ATOMFN_CONFIG`
/// (default `/app/functions.config.json`).
///
/// The file is optional: when it does not exist an empty map is returned silently.
/// Any other read failure (e.g. permission denied, not valid UTF-8) or a JSON parse
/// failure is reported as a warning on stderr and also yields an empty map, so the
/// runtime still starts.
fn load_config_file() -> HashMap<String, FunctionConfig> {
    let path = std::env::var("ATOMFN_CONFIG")
        .unwrap_or_else(|_| "/app/functions.config.json".to_string());
    let content = match std::fs::read_to_string(&path) {
        Ok(content) => content,
        // A missing config file is the normal "no overrides" case: stay quiet.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return HashMap::new(),
        // The file exists but cannot be read: surface it instead of silently
        // running with defaults.
        Err(e) => {
            eprintln!("警告:配置文件读取失败 {}: {}", path, e);
            return HashMap::new();
        }
    };
    match serde_json::from_str::<ConfigFile>(&content) {
        Ok(cfg) => cfg.functions,
        Err(e) => {
            eprintln!("警告:配置文件解析失败 {}: {}", path, e);
            HashMap::new()
        }
    }
}

/// Starts the blocking HTTP runtime; under normal operation it never returns.
///
/// Reads `ATOMFN_PORT` (falls back to [`DEFAULT_PORT`] when unset or not a valid
/// `u16`), and `ATOMFN_PROJECT` / `ATOMFN_BUNDLE` (both default to `"default"`). It
/// passes the config-file contents to `registry.load_configs`, binds `0.0.0.0:<port>`
/// and serves every accepted connection on its own thread.
///
/// # Panics
/// Panics if the listening socket cannot be bound.
pub fn serve(mut registry: Registry) {
    let port = match std::env::var("ATOMFN_PORT")
        .ok()
        .and_then(|raw| raw.parse::<u16>().ok())
    {
        Some(port) => port,
        None => DEFAULT_PORT,
    };
    let env_or_default =
        |key: &str| std::env::var(key).unwrap_or_else(|_| "default".to_string());
    let project = env_or_default("ATOMFN_PROJECT");
    let bundle = env_or_default("ATOMFN_BUNDLE");

    registry.load_configs(load_config_file());
    let shared = Arc::new(ServerState {
        registry,
        project,
        bundle,
    });

    let listener = TcpListener::bind(("0.0.0.0", port)).expect("无法绑定端口");
    println!("atomfn 运行时已启动 :{}", port);

    // Failed accepts are skipped. Each connection runs on its own thread with a handle
    // to the shared state; per-connection I/O errors are deliberately ignored.
    for conn in listener.incoming().filter_map(Result::ok) {
        let worker_state = Arc::clone(&shared);
        thread::spawn(move || {
            let _ = handle_connection(conn, worker_state);
        });
    }
}

/// Serves exactly one HTTP request arriving on `stream`.
///
/// Routes:
/// * `GET /__health` returns `{ "ok": true, "functions": ... }` from the registry.
/// * `POST /invoke/<name>` is dispatched to `route_invoke`.
/// * A request that `parse_request` rejects gets `400`; anything else gets `404`.
fn handle_connection(stream: TcpStream, state: Arc<ServerState>) -> std::io::Result<()> {
    // The parser consumes `stream`, so responses go through a second handle.
    let mut out = stream.try_clone()?;
    let parsed = match parse_request(stream) {
        Some(parsed) => parsed,
        None => return write_simple(&mut out, 400, "Bad Request", "bad request"),
    };

    match (parsed.method.as_str(), parsed.path.as_str()) {
        ("GET", "/__health") => {
            let health = json!({ "ok": true, "functions": state.registry.functions() });
            write_json(&mut out, 200, &health)
        }
        ("POST", path) => match path.strip_prefix("/invoke/") {
            Some(name) => route_invoke(&mut out, &state, name, &parsed),
            None => write_simple(&mut out, 404, "Not Found", "not found"),
        },
        _ => write_simple(&mut out, 404, "Not Found", "not found"),
    }
}

/// Decodes one `application/x-www-form-urlencoded` component (a query key or value).
///
/// `+` becomes a space, and each `%XX` escape (exactly two ASCII hex digits) becomes that
/// byte. The decoded bytes are then read as UTF-8, so multi-byte escapes such as
/// `%E4%B8%AD` yield the original character. Invalid UTF-8 is replaced with U+FFFD.
///
/// A `%` that is not followed by two hex digits is kept literally, and the characters
/// after it are decoded normally. This mirrors the WHATWG urlencoded parser, as used by
/// JavaScript's `URLSearchParams`.
fn urldecode(s: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let input = s.as_bytes();
    // Collect raw bytes first: a single character may span several `%XX` escapes.
    let mut decoded = Vec::with_capacity(input.len());
    let mut i = 0;
    while i < input.len() {
        match input[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                let hi = input.get(i + 1).copied().and_then(hex_value);
                let lo = input.get(i + 2).copied().and_then(hex_value);
                match hi.zip(lo) {
                    Some((hi, lo)) => {
                        decoded.push((hi << 4) | lo);
                        i += 3;
                    }
                    None => {
                        // Not a valid escape: keep the '%' and decode what follows normally.
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

fn route_invoke(
    writer: &mut TcpStream,
    state: &ServerState,
    name: &str,
    request: &ParsedRequest,
) -> std::io::Result<()> {
    // 解析 query 参数
    let mut event = json!({});
    if let Some(query) = &request.query {
        for pair in query.split('&') {
            if let Some((k, v)) = pair.split_once('=') {
                let key = urldecode(k);
                let val = urldecode(v);
                event[key] = Value::String(val);
            }
        }
    }
    
    if !request.body.is_empty() {
        if let Ok(Value::Object(body_map)) = serde_json::from_slice::<Value>(&request.body) {
            if let Value::Object(ref mut event_map) = event {
                for (k, v) in body_map {
                    event_map.insert(k, v);
                }
            }
        }
    }
    let request_id = request
        .request_id
        .clone()
        .unwrap_or_else(generate_request_id);
    let ctx = Ctx {
        request_id: request_id.clone(),
        project: state.project.clone(),
        bundle: state.bundle.clone(),
        function: name.to_string(),
    };

    if let Some(handler) = state.registry.one_shot.get(name) {
        let outcome = handler(event, &ctx);
        let (status, body) = map_outcome(outcome);
        return write_json_with_id(writer, status, &body, &request_id);
    }

    if let Some(handler) = state.registry.stream.get(name) {
        return write_stream(writer, handler(event, &ctx), &request_id);
    }

    let body = system_error("NOT_FOUND", &format!("未找到函数 {}", name), false);
    write_json_with_id(writer, 404, &body, &request_id)
}

fn map_outcome(outcome: Outcome) -> (u16, Value) {
    match outcome {
        Outcome::Ok(data) => (200, json!({ "ok": true, "data": data })),
        Outcome::Business { code, data } => (
            200,
            json!({ "ok": false, "error": { "type": "business", "code": code, "data": data } }),
        ),
        Outcome::Validation(message) => (422, system_error("VALIDATION_FAILED", &message, false)),
        Outcome::Crash(message) => (500, system_error("FUNCTION_ERROR", &message, false)),
    }
}

fn system_error(code: &str, message: &str, retryable: bool) -> Value {
    json!({
        "ok": false,
        "error": { "type": "system", "code": code, "message": message, "retryable": retryable }
    })
}

fn write_stream(
    writer: &mut TcpStream,
    result: Result<crate::registry::StreamItems, Outcome>,
    request_id: &str,
) -> std::io::Result<()> {
    match result {
        Err(Outcome::Validation(message)) => {
            let body = system_error("VALIDATION_FAILED", &message, false);
            write_json_with_id(writer, 422, &body, request_id)
        }
        Err(outcome) => {
            write_sse_head(writer, request_id)?;
            write_sse_event(writer, "error", &outcome_error_value(outcome))?;
            Ok(())
        }
        Ok(items) => {
            write_sse_head(writer, request_id)?;
            for item in items {
                write_sse_event(writer, "chunk", &item)?;
            }
            write_sse_event(writer, "done", &json!({}))
        }
    }
}

/// Builds the bare `error` object (no `ok` envelope) sent as an SSE `error` event.
///
/// Business outcomes keep their `code`/`data`. Everything else becomes a non-retryable
/// `system` error. `Outcome::Ok` is not expected here and is reported as a
/// `FUNCTION_ERROR` with an empty message.
fn outcome_error_value(outcome: Outcome) -> Value {
    let (code, message) = match outcome {
        Outcome::Business { code, data } => {
            return json!({ "type": "business", "code": code, "data": data });
        }
        Outcome::Validation(message) => ("VALIDATION_FAILED", json!(message)),
        Outcome::Crash(message) => ("FUNCTION_ERROR", json!(message)),
        Outcome::Ok(_) => ("FUNCTION_ERROR", json!("")),
    };
    json!({ "type": "system", "code": code, "message": message, "retryable": false })
}

/// Reads and parses one HTTP/1.x request from `stream`.
///
/// Only `Content-Length` body framing is understood; a chunked body is not decoded.
/// Only the `Content-Length` and `x-atomfn-request-id` headers are kept (names matched
/// case-insensitively), and header lines without a `:` are ignored.
///
/// Returns `None` (the caller answers `400`) when:
/// * EOF arrives before a request line;
/// * the request line lacks a method or target;
/// * an I/O or UTF-8 error occurs while reading the head;
/// * `Content-Length` is not a valid non-negative integer;
/// * the connection ends before the declared body is complete.
fn parse_request(stream: TcpStream) -> Option<ParsedRequest> {
    let mut reader = BufReader::new(stream);

    let mut request_line = String::new();
    if reader.read_line(&mut request_line).ok()? == 0 {
        return None;
    }
    let mut parts = request_line.split_whitespace();
    let method = parts.next()?.to_string();
    let target = parts.next()?;
    let (path, query) = match target.split_once('?') {
        Some((path, query)) => (path.to_string(), Some(query.to_string())),
        None => (target.to_string(), None),
    };

    let mut content_length: u64 = 0;
    let mut request_id = None;
    // One buffer reused for every header line.
    let mut line = String::new();
    loop {
        line.clear();
        if reader.read_line(&mut line).ok()? == 0 {
            // EOF before the blank line: treat the head as complete.
            break;
        }
        let trimmed = line.trim_end_matches(['\r', '\n']);
        if trimmed.is_empty() {
            break;
        }
        if let Some((key, value)) = trimmed.split_once(':') {
            let key = key.trim();
            let value = value.trim();
            if key.eq_ignore_ascii_case("content-length") {
                // RFC 9112 §6.3: an invalid Content-Length makes the framing
                // unrecoverable, so reject the request instead of assuming 0.
                content_length = value.parse().ok()?;
            } else if key.eq_ignore_ascii_case("x-atomfn-request-id") {
                // This value is echoed back as a response header. Replace any bare CR
                // with SP (RFC 9112 §2.2) so it cannot end that header line early.
                request_id = Some(value.replace('\r', " "));
            }
        }
    }

    // Read at most `content_length` bytes as they arrive rather than allocating the full
    // declared size up front, so a bogus huge Content-Length cannot force a huge allocation.
    let mut body = Vec::new();
    if content_length > 0 {
        reader
            .by_ref()
            .take(content_length)
            .read_to_end(&mut body)
            .ok()?;
        if body.len() as u64 != content_length {
            // The peer closed the connection before sending the whole body.
            return None;
        }
    }

    Some(ParsedRequest {
        method,
        path,
        query,
        request_id,
        body,
    })
}

/// Returns an HTTP reason phrase for `status`.
///
/// Covers every status this runtime emits plus other common codes. Unlisted codes get a
/// generic phrase for their class, so an unexpected error status is never labelled
/// "OK". Unlisted 2xx codes keep the previous "OK" phrase. RFC 9112 lets clients
/// ignore the phrase, so only its accuracy is at stake, not framing.
fn status_text(status: u16) -> &'static str {
    match status {
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        408 => "Request Timeout",
        409 => "Conflict",
        413 => "Content Too Large",
        422 => "Unprocessable Entity",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        100..=199 => "Informational",
        200..=299 => "OK",
        300..=399 => "Redirection",
        400..=499 => "Client Error",
        500..=599 => "Server Error",
        _ => "Unknown",
    }
}

/// Writes a complete JSON response (status line, headers, compact JSON body) in a single
/// write, with an explicit `Content-Length` and `Connection: close`.
fn write_json(writer: &mut TcpStream, status: u16, body: &Value) -> std::io::Result<()> {
    let payload = body.to_string();
    let mut response = format!(
        "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
        status,
        status_text(status),
        payload.len()
    );
    response.push_str(&payload);
    writer.write_all(response.as_bytes())
}

/// Like a plain JSON response, but also echoes `request_id` in the
/// `x-atomfn-request-id` header. The whole response is sent in a single write.
fn write_json_with_id(
    writer: &mut TcpStream,
    status: u16,
    body: &Value,
    request_id: &str,
) -> std::io::Result<()> {
    let payload = body.to_string();
    let head = format!(
        "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nx-atomfn-request-id: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
        status,
        status_text(status),
        request_id,
        payload.len()
    );
    let mut response = head.into_bytes();
    response.extend_from_slice(payload.as_bytes());
    writer.write_all(&response)
}

/// Writes a complete `text/plain` response with the caller-supplied status code, reason
/// phrase (`text`) and body, in a single write with `Connection: close`.
fn write_simple(
    writer: &mut TcpStream,
    status: u16,
    text: &str,
    body: &str,
) -> std::io::Result<()> {
    let head = format!(
        "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
        status,
        text,
        body.len()
    );
    let response = [head.as_str(), body].concat();
    writer.write_all(response.as_bytes())
}

/// Writes the status line and headers that open a `text/event-stream` response, echoing
/// `request_id` in `x-atomfn-request-id`.
///
/// The stream is sent without `Content-Length` or chunked encoding, so its end is marked
/// by the server closing the connection (RFC 9112 §6.3). The header therefore says
/// `Connection: close` rather than advertising a keep-alive that never happens, matching
/// the other response writers in this runtime.
fn write_sse_head(writer: &mut TcpStream, request_id: &str) -> std::io::Result<()> {
    let head = format!(
        "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nCache-Control: no-cache\r\nx-atomfn-request-id: {}\r\nConnection: close\r\n\r\n",
        request_id
    );
    writer.write_all(head.as_bytes())
}

/// Writes one SSE frame and flushes it.
///
/// The frame is an `event:` line, a `data:` line holding the JSON text of `data`, and a
/// terminating blank line.
fn write_sse_event(writer: &mut TcpStream, event: &str, data: &Value) -> std::io::Result<()> {
    let encoded = data.to_string();
    let frame = ["event: ", event, "\ndata: ", encoded.as_str(), "\n\n"].concat();
    writer.write_all(frame.as_bytes())?;
    writer.flush()
}