//! Documentation: a minimal MCP (Model Context Protocol) server that speaks JSON-RPC 2.0 over stdin/stdout.
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::{collections::HashMap, error::Error};
use tokio::{
    io::{self as tokio_io, AsyncWriteExt, BufReader},
};
use tokio_util::codec::{FramedRead, LinesCodec};

// MCP message types

/// A JSON-RPC 2.0 request or notification.
///
/// `id` falls back to `null` (via `default_id`) when the field is absent, so a
/// notification, which carries no `id`, still deserializes.
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonRpcRequest {
    /// Protocol version string; not validated on deserialization.
    pub jsonrpc: String,
    /// Request identifier; `Value::Null` when the field is absent.
    #[serde(default = "default_id")]
    pub id: Value,
    /// Method name, e.g. `"initialize"` or `"tools/call"`.
    pub method: String,
    /// Method parameters. Left out of serialized output when `None`:
    /// JSON-RPC 2.0 allows `params` to be omitted, but not to be `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}

/// Serde default for `JsonRpcRequest::id`: a message that arrives without an
/// `id` (a notification) is given `null`.
fn default_id() -> Value {
    // serde_json's `Value::default()` is `Value::Null`.
    Value::default()
}

/// A JSON-RPC 2.0 response.
///
/// Exactly one of `result` / `error` is expected to be `Some`. The `None` one is
/// left out of the serialized JSON, because JSON-RPC 2.0 requires that `result`
/// MUST NOT exist on error and `error` MUST NOT exist on success.
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonRpcResponse {
    /// Protocol version string.
    pub jsonrpc: String,
    /// Id of the request being answered; `null` when it could not be determined.
    pub id: Value,
    /// Successful result; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error details; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}

/// The `error` member of a JSON-RPC 2.0 response.
#[derive(Debug, Serialize, Deserialize)]
pub struct JsonRpcError {
    /// Numeric error code (e.g. `-32700` parse error, `-32601` method not found).
    pub code: i32,
    /// Short human-readable description of what went wrong.
    pub message: String,
    /// Optional structured detail; not emitted at all when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}

/// Type-erased tool callback stored by [`McpServer`].
///
/// It takes the call's `arguments` value and returns the handle of a spawned
/// Tokio task that resolves to the tool's output or an error message.
pub type ToolHandler =
    Box<dyn Fn(Value) -> tokio::task::JoinHandle<Result<Value, String>> + Send + Sync>;

/// A minimal MCP server: an identity (name and version) plus a registry of
/// tools, served line-by-line over stdin/stdout by `run_stdio`.
pub struct McpServer {
    /// Reported as `serverInfo.name` in the `initialize` response.
    name: String,
    /// Reported as `serverInfo.version` in the `initialize` response.
    version: String,
    /// Tool name -> (description, JSON input schema, handler).
    tools: HashMap<String, (String, Value, ToolHandler)>,
}

impl McpServer {
    pub fn new(name: &str, version: &str) -> Self {
        McpServer {
            name: name.to_string(),
            version: version.to_string(),
            tools: HashMap::new(),
        }
    }

    pub fn add_tool<F, Fut>(&mut self, name: &str, description: &str, input_schema: Value, handler: F)
    where
        F: Fn(Value) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = Result<Value, String>> + Send + 'static,
    {
        let handler_boxed: ToolHandler = Box::new(move |args| {
            let handler_clone = handler.clone();
            tokio::spawn(async move {
                handler_clone(args).await
            })
        });
        
        self.tools.insert(
            name.to_string(),
            (description.to_string(), input_schema, handler_boxed),
        );
    }

    async fn handle_message(&self, msg: String) -> Option<String> {
        eprintln!("Received message: {}", msg);
        
        let request: JsonRpcRequest = match serde_json::from_str(&msg) {
            Ok(req) => req,
            Err(e) => {
                eprintln!("Error parsing request: {}", e);
                let error_response = JsonRpcResponse {
                    jsonrpc: "2.0".to_string(),
                    id: Value::Null,
                    result: None,
                    error: Some(JsonRpcError {
                        code: -32700,
                        message: "Parse error".to_string(),
                        data: Some(json!({"error": e.to_string()}))
                    })
                };
                let response_str = serde_json::to_string(&error_response).unwrap();
                eprintln!("Sending error response: {}", response_str);
                return Some(response_str);
            }
        };

        // Handle initialization
        if request.method == "initialize" {
            let response = JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                id: request.id,
                result: Some(json!({
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {
                            "listChanged": true
                        }
                    },
                    "serverInfo": {
                        "name": self.name,
                        "version": self.version
                    }
                })),
                error: None,
            };
            let response_str = serde_json::to_string(&response).unwrap();
            eprintln!("Sending initialize response: {}", response_str);
            return Some(response_str);
        }

        // Handle tools/list
        if request.method == "tools/list" {
            let tools: Vec<Value> = self.tools
                .iter()
                .map(|(name, (description, input_schema, _))| {
                    json!({
                        "name": name,
                        "description": description,
                        "inputSchema": input_schema,
                    })
                })
                .collect();

            let response = JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                id: request.id,
                result: Some(json!({
                    "tools": tools,
                })),
                error: None,
            };
            return Some(serde_json::to_string(&response).unwrap());
        }
        
        // Handle resources/list
        if request.method == "resources/list" {
            let response = JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                id: request.id,
                result: Some(json!({
                    "resources": []
                })),
                error: None,
            };
            return Some(serde_json::to_string(&response).unwrap());
        }
        
        // Handle prompts/list
        if request.method == "prompts/list" {
            let response = JsonRpcResponse {
                jsonrpc: "2.0".to_string(),
                id: request.id,
                result: Some(json!({
                    "prompts": []
                })),
                error: None,
            };
            return Some(serde_json::to_string(&response).unwrap());
        }

        // Handle tools/call
        if request.method == "tools/call" {
            let params = match request.params {
                Some(p) => p,
                None => {
                    let error_response = JsonRpcResponse {
                        jsonrpc: "2.0".to_string(),
                        id: request.id,
                        result: None,
                        error: Some(JsonRpcError {
                            code: -32602,
                            message: "Invalid params".to_string(),
                            data: None
                        })
                    };
                    return Some(serde_json::to_string(&error_response).unwrap());
                }
            };

            let tool_name = match params.get("name") {
                Some(Value::String(name)) => name,
                _ => {
                    let error_response = JsonRpcResponse {
                        jsonrpc: "2.0".to_string(),
                        id: request.id,
                        result: None,
                        error: Some(JsonRpcError {
                            code: -32602,
                            message: "Tool name required".to_string(),
                            data: None
                        })
                    };
                    return Some(serde_json::to_string(&error_response).unwrap());
                }
            };

            let arguments = match params.get("arguments") {
                Some(args) => args.clone(),
                None => json!({}),
            };

            if let Some((_, _, handler)) = self.tools.get(tool_name) {
                let handle = handler(arguments);
                match handle.await.unwrap() {
                    Ok(val) => {
                        let response = JsonRpcResponse {
                            jsonrpc: "2.0".to_string(),
                            id: request.id,
                            result: Some(json!({
                                "content": [
                                    {
                                        "type": "text",
                                        "text": val.to_string()
                                    }
                                ],
                                "isError": false
                            })),
                            error: None,
                        };
                        return Some(serde_json::to_string(&response).unwrap());
                    },
                    Err(e) => {
                        let error_response = JsonRpcResponse {
                            jsonrpc: "2.0".to_string(),
                            id: request.id,
                            result: None,
                            error: Some(JsonRpcError {
                                code: -32000,  // Internal tool execution error
                                message: e,
                                data: None
                            })
                        };
                        return Some(serde_json::to_string(&error_response).unwrap());
                    },
                };
            } else {
                let error_response = JsonRpcResponse {
                    jsonrpc: "2.0".to_string(),
                    id: request.id,
                    result: None,
                    error: Some(JsonRpcError {
                        code: -32601,
                        message: format!("Tool not found: {}", tool_name),
                        data: None
                    })
                };
                return Some(serde_json::to_string(&error_response).unwrap());
            }
        }

        // For notifications, the client doesn't expect a response
        if request.method.starts_with("notifications/") {
            eprintln!("Handling notification: {}, not sending a response", request.method);
            return None;
        }

        // Method not found
        eprintln!("Unknown method: {}", request.method);
        let error_response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: request.id,
            result: None,
            error: Some(JsonRpcError {
                code: -32601,
                message: format!("Method not found: {}", request.method),
                data: None
            })
        };
        let response_str = serde_json::to_string(&error_response).unwrap();
        eprintln!("Sending method not found response: {}", response_str);
        Some(response_str)
    }

    pub async fn run_stdio(self) -> Result<(), Box<dyn Error>> {
        let stdin = tokio_io::stdin();
        let mut stdout = tokio_io::stdout();
        let reader = BufReader::new(stdin);
        let mut lines = FramedRead::new(reader, LinesCodec::new());

        while let Some(line_result) = lines.next().await {
            match line_result {
                Ok(line) => {
                    eprintln!("Processing line: {}", line);
                    if let Some(response) = self.handle_message(line).await {
                        eprintln!("Sending response: {}", response);
                        stdout.write_all(response.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    } else {
                        eprintln!("No response to send (notification handling)");
                    }
                }
                Err(e) => eprintln!("Error reading line: {}", e),
            }
        }

        Ok(())
    }
}