Skip to main content

indodax_cli/mcp/tools/
mod.rs

1pub mod account;
2pub mod auth;
3pub mod funding;
4pub mod market;
5pub mod paper;
6pub mod trade;
7
8use std::sync::Arc;
9
10use tokio::sync::Mutex;
11use serde_json::{Map, Value};
12
13use rmcp::model::{
14    CallToolRequestParams, CallToolResult, Content, Implementation, InitializeResult,
15    ListToolsResult, PaginatedRequestParams, ServerCapabilities, Tool,
16};
17use rmcp::service::{RequestContext, RoleServer};
18use rmcp::ErrorData as McpError;
19
20use crate::commands::helpers;
21use crate::config::IndodaxConfig;
22use crate::errors::IndodaxError;
23use crate::mcp::safety::SafetyConfig;
24use crate::mcp::service::ServiceGroup;
25use crate::client::IndodaxClient;
26
27/// The MCP server exposing Indodax trading functionality as tools.
28#[derive(Debug, Clone)]
29pub struct IndodaxMcp {
30    pub client: Arc<IndodaxClient>,
31    pub config: Arc<Mutex<IndodaxConfig>>,
32    pub safety: SafetyConfig,
33    pub enabled_groups: Vec<ServiceGroup>,
34    pub paper_mutex: Arc<tokio::sync::Mutex<()>>,
35}
36
37impl IndodaxMcp {
38    pub fn new(
39        client: IndodaxClient,
40        config: IndodaxConfig,
41        safety: SafetyConfig,
42        enabled_groups: Vec<ServiceGroup>,
43    ) -> Self {
44        Self {
45            client: Arc::new(client),
46            config: Arc::new(Mutex::new(config)),
47            safety,
48            enabled_groups,
49            paper_mutex: Arc::new(tokio::sync::Mutex::new(())),
50        }
51    }
52
53    pub fn is_group_enabled(&self, group: &ServiceGroup) -> bool {
54        self.enabled_groups.contains(group)
55    }
56
57    pub fn str_param(description: &str, _required: bool, default_: Option<&str>) -> Value {
58        let mut schema = serde_json::json!({
59            "type": "string",
60            "description": description,
61        });
62        if let Some(d) = default_ {
63            schema["default"] = Value::String(d.to_string());
64        }
65        schema
66    }
67
68    pub fn num_param(description: &str, _required: bool) -> Value {
69        serde_json::json!({
70            "type": "number",
71            "description": description,
72        })
73    }
74
75    pub fn bool_param(description: &str) -> Value {
76        serde_json::json!({
77            "type": "boolean",
78            "description": description,
79        })
80    }
81
82    /// Helper to get account info from Indodax API.
83    pub async fn get_account_info(&self) -> Result<Value, IndodaxError> {
84        self.client
85            .private_post_v1::<Value>("getInfo", &std::collections::HashMap::new())
86            .await
87    }
88
89    pub fn tool_def(name: &str, description: &str, properties: Value, required: Vec<&str>) -> Tool {
90        let mut schema = Map::new();
91        schema.insert("type".to_string(), Value::String("object".to_string()));
92
93        if let Value::Object(props) = properties {
94            if !props.is_empty() {
95                schema.insert("properties".to_string(), Value::Object(props));
96            }
97        }
98
99        if !required.is_empty() {
100            let req_values: Vec<Value> = required
101                .iter()
102                .map(|s| Value::String(s.to_string()))
103                .collect();
104            schema.insert("required".to_string(), Value::Array(req_values));
105        }
106
107        Tool::new(name.to_string(), description.to_string(), Arc::new(schema))
108    }
109
110    pub fn get_str(args: &Map<String, Value>, name: &str) -> Option<String> {
111        args.get(name)
112            .and_then(|v| v.as_str())
113            .map(|s| s.to_string())
114    }
115
116    pub fn get_num(args: &Map<String, Value>, name: &str) -> Option<f64> {
117        args.get(name).and_then(|v| {
118            v.as_f64()
119                .or_else(|| v.as_str().and_then(|s| s.parse::<f64>().ok()))
120        })
121    }
122
123    pub fn get_bool(args: &Map<String, Value>, name: &str) -> bool {
124        Self::get_opt_bool(args, name).unwrap_or(false)
125    }
126
127    pub fn get_opt_bool(args: &Map<String, Value>, name: &str) -> Option<bool> {
128        args.get(name).and_then(|v| v.as_bool())
129    }
130
131    pub fn ok_result(text: String) -> CallToolResult {
132        CallToolResult::success(vec![Content::text(text)])
133    }
134
135    pub fn error_result(text: String) -> CallToolResult {
136        let envelope = serde_json::json!({
137            "error": true,
138            "message": text,
139            "error_type": "mcp_error",
140        });
141        CallToolResult::error(vec![Content::text(
142            serde_json::to_string_pretty(&envelope).unwrap_or(text),
143        )])
144    }
145
146    pub fn validation_error_result(text: String) -> CallToolResult {
147        let envelope = serde_json::json!({
148            "error": true,
149            "message": text,
150            "error_type": "validation_error",
151        });
152        CallToolResult::error(vec![Content::text(
153            serde_json::to_string_pretty(&envelope).unwrap_or(text),
154        )])
155    }
156
157    pub fn error_from_indodax(err: &IndodaxError) -> CallToolResult {
158        let envelope = serde_json::json!({
159            "error": true,
160            "message": err.to_string(),
161            "error_type": err.category(),
162        });
163        CallToolResult::error(vec![Content::text(
164            serde_json::to_string_pretty(&envelope).unwrap_or_else(|_| err.to_string()),
165        )])
166    }
167
168    pub fn json_result(value: Value) -> CallToolResult {
169        let text = serde_json::to_string_pretty(&value).unwrap_or_default();
170        Self::ok_result(text)
171    }
172
173    pub fn json_result_with_warning(mut value: Value, warning: Option<String>) -> CallToolResult {
174        if let Some(w) = warning {
175            if let Some(obj) = value.as_object_mut() {
176                obj.insert("warning".to_string(), Value::String(w));
177            }
178        }
179        Self::json_result(value)
180    }
181}
182
183impl rmcp::handler::server::ServerHandler for IndodaxMcp {
184    fn get_info(&self) -> InitializeResult {
185        InitializeResult::new(
186            ServerCapabilities::builder()
187                .enable_tools()
188                .build(),
189        )
190        .with_server_info(Implementation::new(
191            "indodax-cli",
192            env!("CARGO_PKG_VERSION"),
193        ))
194    }
195
196    async fn list_tools(
197        &self,
198        _request: Option<PaginatedRequestParams>,
199        _context: RequestContext<RoleServer>,
200    ) -> Result<ListToolsResult, McpError> {
201        let tools = self.all_tools();
202        Ok(ListToolsResult::with_all_items(tools))
203    }
204
205    async fn call_tool(
206        &self,
207        request: CallToolRequestParams,
208        _context: RequestContext<RoleServer>,
209    ) -> Result<CallToolResult, McpError> {
210        let name = request.name.to_string();
211        let args = request.arguments.unwrap_or_default();
212
213        let result = match name.as_str() {
214            // Market (public, no auth needed)
215            "server_time" => self.handle_server_time().await,
216            "ticker" => {
217                let pair = helpers::normalize_pair(
218                    &Self::get_str(&args, "pair").unwrap_or_else(|| "btc_idr".into())
219                );
220                self.handle_ticker(&pair).await
221            }
222            "ticker_all" => self.handle_ticker_all().await,
223            "pairs" => self.handle_pairs().await,
224            "summaries" => self.handle_summaries().await,
225            "orderbook" => {
226                let pair = helpers::normalize_pair(
227                    &Self::get_str(&args, "pair").unwrap_or_else(|| "btc_idr".into())
228                );
229                self.handle_orderbook(&pair).await
230            }
231            "trades" => {
232                let pair = helpers::normalize_pair(
233                    &Self::get_str(&args, "pair").unwrap_or_else(|| "btc_idr".into())
234                );
235                self.handle_trades(&pair).await
236            }
237            "ohlc" => {
238                let symbol = helpers::normalize_pair(
239                    &Self::get_str(&args, "symbol").unwrap_or_else(|| "btc_idr".into())
240                ).replace('_', "").to_uppercase();
241                let timeframe =
242                    Self::get_str(&args, "timeframe").unwrap_or_else(|| "60".into());
243                let from = Self::get_num(&args, "from");
244                let to = Self::get_num(&args, "to");
245                self.handle_ohlc(&symbol, &timeframe, from, to).await
246            }
247            "price_increments" => self.handle_price_increments().await,
248
249            // Trade
250            "buy_order" => {
251                let acknowledged = Self::get_bool(&args, "acknowledged");
252                if let Err(msg) =
253                    self.safety.check_operation(&ServiceGroup::Trade, acknowledged)
254                {
255                    return Ok(Self::error_result(msg));
256                }
257                let pair = helpers::normalize_pair(
258                    &Self::get_str(&args, "pair").unwrap_or_default()
259                );
260                let idr = Self::get_num(&args, "idr").unwrap_or(0.0);
261                let price = Self::get_num(&args, "price");
262                self.handle_buy_order(&pair, idr, price).await
263            }
264            "sell_order" => {
265                let acknowledged = Self::get_bool(&args, "acknowledged");
266                if let Err(msg) =
267                    self.safety.check_operation(&ServiceGroup::Trade, acknowledged)
268                {
269                    return Ok(Self::error_result(msg));
270                }
271                let pair = helpers::normalize_pair(
272                    &Self::get_str(&args, "pair").unwrap_or_default()
273                );
274                let price = Self::get_num(&args, "price");
275                let amount = match Self::get_num(&args, "amount") {
276                    Some(v) if v > 0.0 => v,
277                    Some(v) => return Ok(Self::validation_error_result(format!("Amount must be positive, got {}", v))),
278                    None => return Ok(Self::validation_error_result("Missing required parameter: amount".into())),
279                };
280                let order_type = Self::get_str(&args, "order_type").unwrap_or_else(|| "limit".into());
281                self.handle_sell_order(&pair, price, amount, &order_type).await
282            }
283
284            // Account
285            "account_info" => self.handle_account_info().await,
286            "balance" => self.handle_balance().await,
287            "open_orders" => {
288                let pair = Self::get_str(&args, "pair").map(|p| helpers::normalize_pair(&p));
289                self.handle_open_orders(pair.as_deref()).await
290            }
291            "order_history" => {
292                let symbol = helpers::normalize_pair(
293                    &Self::get_str(&args, "symbol").unwrap_or_else(|| "btc_idr".into())
294                );
295                let limit = Self::get_num(&args, "limit");
296                self.handle_order_history(&symbol, limit).await
297            }
298            "trade_history" => {
299                let symbol = helpers::normalize_pair(
300                    &Self::get_str(&args, "symbol").unwrap_or_else(|| "btc_idr".into())
301                );
302                let limit = Self::get_num(&args, "limit");
303                self.handle_trade_history(&symbol, limit).await
304            }
305            "get_order" => {
306                let order_id = match Self::get_num(&args, "order_id") {
307                    Some(v) => {
308                        if v.fract() != 0.0 {
309                            return Ok(Self::validation_error_result(format!("order_id must be a whole number, got {}", v)));
310                        }
311                        v
312                    }
313                    None => return Ok(Self::validation_error_result("Missing required parameter: order_id".into())),
314                };
315                let pair = match Self::get_str(&args, "pair") {
316                    Some(v) => helpers::normalize_pair(&v),
317                    None => return Ok(Self::validation_error_result("Missing required parameter: pair".into())),
318                };
319                self.handle_get_order(order_id, &pair).await
320            }
321            "trans_history" => self.handle_trans_history().await,
322            "cancel_order" => {
323                let acknowledged = Self::get_bool(&args, "acknowledged");
324                if let Err(msg) =
325                    self.safety.check_operation(&ServiceGroup::Trade, acknowledged)
326                {
327                    return Ok(Self::error_result(msg));
328                }
329                let order_id = match Self::get_num(&args, "order_id") {
330                    Some(v) => {
331                        if v.fract() != 0.0 {
332                            return Ok(Self::validation_error_result(format!("order_id must be a whole number, got {}", v)));
333                        }
334                        v
335                    }
336                    None => return Ok(Self::validation_error_result("Missing required parameter: order_id".into())),
337                };
338                let pair = match Self::get_str(&args, "pair") {
339                    Some(v) => helpers::normalize_pair(&v),
340                    None => return Ok(Self::validation_error_result("Missing required parameter: pair".into())),
341                };
342                let order_type = Self::get_str(&args, "order_type").unwrap_or_default();
343                self.handle_cancel_order(order_id, &pair, &order_type).await
344            }
345            "cancel_all_orders" => {
346                let acknowledged = Self::get_bool(&args, "acknowledged");
347                if let Err(msg) =
348                    self.safety.check_operation(&ServiceGroup::Trade, acknowledged)
349                {
350                    return Ok(Self::error_result(msg));
351                }
352                let pair = Self::get_str(&args, "pair")
353                    .map(|p| helpers::normalize_pair(&p));
354                self.handle_cancel_all_orders(pair.as_deref()).await
355            }
356
357            // Funding
358            "withdraw_fee" => {
359                let currency = Self::get_str(&args, "currency").unwrap_or_default();
360                let network = Self::get_str(&args, "network");
361                self.handle_withdraw_fee(&currency, network.as_deref())
362                    .await
363            }
364            "withdraw" => {
365                let acknowledged = Self::get_bool(&args, "acknowledged");
366                if let Err(msg) =
367                    self.safety.check_operation(&ServiceGroup::Funding, acknowledged)
368                {
369                    return Ok(Self::error_result(msg));
370                }
371                let currency = Self::get_str(&args, "currency").unwrap_or_default();
372                let amount = Self::get_num(&args, "amount").unwrap_or(0.0);
373                let address = Self::get_str(&args, "address").unwrap_or_default();
374                let to_username = Self::get_bool(&args, "to_username");
375                let memo = Self::get_str(&args, "memo");
376                let network = Self::get_str(&args, "network");
377                let callback_url = Self::get_str(&args, "callback_url");
378                self.handle_withdraw(
379                    &currency,
380                    amount,
381                    &address,
382                    to_username,
383                    memo.as_deref(),
384                    network.as_deref(),
385                    callback_url.as_deref(),
386                )
387                .await
388            }
389
390            // Paper
391            "paper_init" => {
392                let idr = Self::get_num(&args, "idr");
393                let btc = Self::get_num(&args, "btc");
394                self.handle_paper_init(idr, btc).await
395            }
396            "paper_reset" => self.handle_paper_reset().await,
397            "paper_balance" => self.handle_paper_balance().await,
398            "paper_buy" | "paper_sell" => {
399                let pair = helpers::normalize_pair(
400                    &Self::get_str(&args, "pair")
401                        .unwrap_or_else(|| "btc_idr".into())
402                );
403                let price = Self::get_num(&args, "price");
404                let amount = Self::get_num(&args, "amount");
405                let idr = Self::get_num(&args, "idr");
406                let side = if name == "paper_buy" {
407                    "buy"
408                } else {
409                    "sell"
410                };
411                self.handle_paper_trade(side, &pair, price, amount, idr).await
412            }
413            "paper_orders" => self.handle_paper_orders().await,
414            "paper_cancel" => {
415                let order_id = match Self::get_num(&args, "order_id") {
416                    Some(v) => {
417                        if v.fract() != 0.0 {
418                            return Ok(Self::validation_error_result(format!("order_id must be a whole number, got {}", v)));
419                        }
420                        v as u64
421                    }
422                    None => return Ok(Self::validation_error_result("Missing required parameter: order_id".into())),
423                };
424                self.handle_paper_cancel(order_id).await
425            }
426            "paper_cancel_all" => self.handle_paper_cancel_all().await,
427            "paper_history" => self.handle_paper_history().await,
428            "paper_status" => self.handle_paper_status().await,
429            "paper_fill" => {
430                let order_id = Self::get_num(&args, "order_id");
431                let price = Self::get_num(&args, "price");
432                let all = Self::get_bool(&args, "all");
433                let fetch = Self::get_bool(&args, "fetch");
434                self.handle_paper_fill(order_id, price, all, fetch).await
435            }
436            "paper_check_fills" => {
437                let prices = Self::get_str(&args, "prices");
438                let fetch = Self::get_bool(&args, "fetch");
439                self.handle_paper_check_fills(prices.as_deref(), fetch).await
440            }
441
442            // Auth
443            "auth_show" => self.handle_auth_show().await,
444            "auth_test" => self.handle_auth_test().await,
445
446            _ => Self::error_result(format!("Unknown tool: {}", name)),
447        };
448
449        Ok(result)
450    }
451}
452
453fn all_tools(mcp: &IndodaxMcp) -> Vec<Tool> {
454    let mut tools: Vec<Tool> = Vec::new();
455
456    if mcp.is_group_enabled(&ServiceGroup::Market) {
457        tools.extend(market::market_tools());
458    }
459    if mcp.is_group_enabled(&ServiceGroup::Account) {
460        tools.extend(account::account_tools());
461    }
462    if mcp.is_group_enabled(&ServiceGroup::Trade) {
463        if let Err(msg) = mcp.safety.check_group(&ServiceGroup::Trade) {
464            eprintln!("[MCP] Warning: {}", msg);
465        }
466        tools.extend(trade::trade_tools());
467    }
468    if mcp.is_group_enabled(&ServiceGroup::Funding) {
469        if let Err(msg) = mcp.safety.check_group(&ServiceGroup::Funding) {
470            eprintln!("[MCP] Warning: {}", msg);
471        }
472        tools.extend(funding::funding_tools());
473    }
474    if mcp.is_group_enabled(&ServiceGroup::Paper) {
475        tools.extend(paper::paper_tools());
476    }
477    if mcp.is_group_enabled(&ServiceGroup::Auth) {
478        tools.extend(auth::auth_tools());
479    }
480
481    tools
482}
483
484impl IndodaxMcp {
485    pub fn all_tools(&self) -> Vec<Tool> {
486        all_tools(self)
487    }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use crate::client::IndodaxClient;
    use crate::config::IndodaxConfig;

    /// Server with only the Market and Paper groups enabled and a default
    /// safety config.
    fn test_mcp() -> IndodaxMcp {
        let client = IndodaxClient::new(None).expect("test client should build");
        let config = IndodaxConfig::default();
        let safety = SafetyConfig::new(false);
        let groups = vec![ServiceGroup::Market, ServiceGroup::Paper];
        IndodaxMcp::new(client, config, safety, groups)
    }

    /// Returns the text of the first content item of a tool result, or an
    /// empty string if there is none.
    fn first_text(result: &CallToolResult) -> String {
        result
            .content
            .first()
            .and_then(|c| c.as_text())
            .map(|t| t.text.clone())
            .unwrap_or_default()
    }

    #[test]
    fn test_get_str() {
        let mut args = Map::new();
        args.insert("name".into(), json!("test_value"));
        assert_eq!(IndodaxMcp::get_str(&args, "name"), Some("test_value".into()));
        assert_eq!(IndodaxMcp::get_str(&args, "missing"), None);
    }

    #[test]
    fn test_get_str_rejects_non_string() {
        let mut args = Map::new();
        args.insert("n".into(), json!(5));
        assert_eq!(IndodaxMcp::get_str(&args, "n"), None);
    }

    #[test]
    fn test_get_num_from_number() {
        let mut args = Map::new();
        args.insert("price".into(), json!(100.5));
        assert_eq!(IndodaxMcp::get_num(&args, "price"), Some(100.5));
    }

    #[test]
    fn test_get_num_from_string() {
        let mut args = Map::new();
        args.insert("amount".into(), json!("50.25"));
        assert_eq!(IndodaxMcp::get_num(&args, "amount"), Some(50.25));
    }

    #[test]
    fn test_get_num_rejects_non_numeric_string() {
        let mut args = Map::new();
        args.insert("amount".into(), json!("abc"));
        assert_eq!(IndodaxMcp::get_num(&args, "amount"), None);
    }

    #[test]
    fn test_get_num_missing() {
        let args = Map::new();
        assert_eq!(IndodaxMcp::get_num(&args, "missing"), None);
    }

    #[test]
    fn test_get_bool_true() {
        let mut args = Map::new();
        args.insert("acknowledged".into(), json!(true));
        assert!(IndodaxMcp::get_bool(&args, "acknowledged"));
    }

    #[test]
    fn test_get_bool_false() {
        let mut args = Map::new();
        args.insert("flag".into(), json!(false));
        assert!(!IndodaxMcp::get_bool(&args, "flag"));
    }

    #[test]
    fn test_get_bool_missing_defaults_false() {
        let args = Map::new();
        assert!(!IndodaxMcp::get_bool(&args, "missing"));
    }

    #[test]
    fn test_get_opt_bool() {
        let mut args = Map::new();
        args.insert("true_val".into(), json!(true));
        args.insert("false_val".into(), json!(false));

        assert_eq!(IndodaxMcp::get_opt_bool(&args, "true_val"), Some(true));
        assert_eq!(IndodaxMcp::get_opt_bool(&args, "false_val"), Some(false));
        assert_eq!(IndodaxMcp::get_opt_bool(&args, "missing"), None);
    }

    #[test]
    fn test_get_opt_bool_ignores_string_booleans() {
        // Safety acknowledgements must be real JSON booleans.
        let mut args = Map::new();
        args.insert("acknowledged".into(), json!("true"));
        assert_eq!(IndodaxMcp::get_opt_bool(&args, "acknowledged"), None);
        assert!(!IndodaxMcp::get_bool(&args, "acknowledged"));
    }

    #[test]
    fn test_tool_def_creates_tool() {
        let properties = serde_json::json!({
            "pair": {
                "type": "string",
                "description": "Trading pair"
            }
        });
        let tool = IndodaxMcp::tool_def("test_tool", "A test tool", properties, vec!["pair"]);
        assert_eq!(tool.name.to_string(), "test_tool");
        assert!(tool.description.is_some_and(|d| d.as_ref() == "A test tool"));
        assert_eq!(tool.input_schema.get("type"), Some(&json!("object")));
        assert_eq!(tool.input_schema.get("required"), Some(&json!(["pair"])));
        assert!(tool.input_schema.contains_key("properties"));
    }

    #[test]
    fn test_tool_def_no_required_params() {
        let properties = serde_json::json!({});
        let tool = IndodaxMcp::tool_def("empty_tool", "No params", properties, vec![]);
        assert_eq!(tool.name.to_string(), "empty_tool");
        assert!(!tool.input_schema.contains_key("required"));
        assert!(!tool.input_schema.contains_key("properties"));
    }

    #[test]
    fn test_str_param() {
        let param = IndodaxMcp::str_param("A test string", false, Some("default"));
        assert_eq!(param["type"], "string");
        assert_eq!(param["default"], "default");
    }

    #[test]
    fn test_str_param_without_default() {
        let param = IndodaxMcp::str_param("A test string", true, None);
        assert_eq!(param["description"], "A test string");
        assert!(param.get("default").is_none());
    }

    #[test]
    fn test_num_param() {
        let param = IndodaxMcp::num_param("A test number", true);
        assert_eq!(param["type"], "number");
    }

    #[test]
    fn test_bool_param() {
        let param = IndodaxMcp::bool_param("A test boolean");
        assert_eq!(param["type"], "boolean");
    }

    #[test]
    fn test_mcp_is_group_enabled() {
        let mcp = test_mcp();
        assert!(mcp.is_group_enabled(&ServiceGroup::Market));
        assert!(mcp.is_group_enabled(&ServiceGroup::Paper));
        assert!(!mcp.is_group_enabled(&ServiceGroup::Trade));
        assert!(!mcp.is_group_enabled(&ServiceGroup::Account));
        assert!(!mcp.is_group_enabled(&ServiceGroup::Funding));
        assert!(!mcp.is_group_enabled(&ServiceGroup::Auth));
    }

    #[test]
    fn test_ok_result() {
        let result = IndodaxMcp::ok_result("success".into());
        assert_eq!(result.is_error, Some(false));
        assert_eq!(first_text(&result), "success");
    }

    #[test]
    fn test_error_result_contains_error() {
        let result = IndodaxMcp::error_result("something failed".into());
        assert_eq!(result.is_error, Some(true));
    }

    #[test]
    fn test_error_result_envelope() {
        let result = IndodaxMcp::error_result("something failed".into());
        let parsed: Value =
            serde_json::from_str(&first_text(&result)).expect("envelope should be JSON");
        assert_eq!(parsed["error"], json!(true));
        assert_eq!(parsed["message"], "something failed");
        assert_eq!(parsed["error_type"], "mcp_error");
    }

    #[test]
    fn test_validation_error_result_contains_validation_type() {
        let result = IndodaxMcp::validation_error_result("bad input".into());
        assert_eq!(result.is_error, Some(true));
        assert!(first_text(&result).contains("validation_error"));
    }

    #[test]
    fn test_json_result() {
        let value = json!({"key": "value", "num": 42});
        let result = IndodaxMcp::json_result(value);
        assert_eq!(result.is_error, Some(false));
    }

    #[test]
    fn test_json_result_with_warning_inserts_warning() {
        let result =
            IndodaxMcp::json_result_with_warning(json!({"k": 1}), Some("careful".into()));
        assert_eq!(result.is_error, Some(false));
        let parsed: Value =
            serde_json::from_str(&first_text(&result)).expect("result should be JSON");
        assert_eq!(parsed["warning"], "careful");
        assert_eq!(parsed["k"], 1);
    }

    #[test]
    fn test_json_result_without_warning_adds_nothing() {
        let result = IndodaxMcp::json_result_with_warning(json!({"k": 1}), None);
        let parsed: Value =
            serde_json::from_str(&first_text(&result)).expect("result should be JSON");
        assert!(parsed.get("warning").is_none());
    }

    #[test]
    fn test_all_tools_respects_groups() {
        let mcp = test_mcp();
        let tools = mcp.all_tools();
        let names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
        assert!(names.contains(&"server_time".to_string()));
        assert!(names.contains(&"ticker".to_string()));
        assert!(names.contains(&"paper_init".to_string()));
        assert!(names.contains(&"paper_balance".to_string()));
        assert!(!names.contains(&"buy_order".to_string()));
        assert!(!names.contains(&"sell_order".to_string()));
        assert!(!names.contains(&"account_info".to_string()));
    }
}