Skip to main content

indodax_cli/mcp/tools/
mod.rs

1pub mod account;
2pub mod auth;
3pub mod funding;
4pub mod market;
5pub mod paper;
6pub mod trade;
7
8use std::sync::Arc;
9
10use tokio::sync::Mutex;
11use serde_json::{Map, Value};
12
13use rmcp::model::{
14    CallToolRequestParams, CallToolResult, Content, Implementation, InitializeResult,
15    ListToolsResult, PaginatedRequestParams, ServerCapabilities, Tool,
16};
17use rmcp::service::{RequestContext, RoleServer};
18use rmcp::ErrorData as McpError;
19
20use crate::commands::helpers;
21use crate::config::IndodaxConfig;
22use crate::errors::IndodaxError;
23use crate::mcp::safety::SafetyConfig;
24use crate::mcp::service::ServiceGroup;
25use crate::client::IndodaxClient;
26
/// The MCP server exposing Indodax trading functionality as tools.
///
/// Cloning is cheap for the client and config (both behind `Arc`), so clones
/// share the same API client and the same configuration instance.
#[derive(Debug, Clone)]
pub struct IndodaxMcp {
    /// Indodax API client used by the tool handlers (e.g. `getInfo` in
    /// `get_account_info`); shared between clones via `Arc`.
    pub client: Arc<IndodaxClient>,
    /// CLI configuration behind an async (`tokio`) mutex so handlers can lock it
    /// across `.await`; NOTE(review): which handlers mutate it is not visible here.
    pub config: Arc<Mutex<IndodaxConfig>>,
    /// Safety policy consulted before state-changing tools (`check_operation`)
    /// and when listing the trade/funding groups (`check_group`).
    pub safety: SafetyConfig,
    /// Service groups whose tools are advertised by `list_tools`.
    pub enabled_groups: Vec<ServiceGroup>,
}
35
36impl IndodaxMcp {
37    pub fn new(
38        client: IndodaxClient,
39        config: IndodaxConfig,
40        safety: SafetyConfig,
41        enabled_groups: Vec<ServiceGroup>,
42    ) -> Self {
43        Self {
44            client: Arc::new(client),
45            config: Arc::new(Mutex::new(config)),
46            safety,
47            enabled_groups,
48        }
49    }
50
51    pub fn is_group_enabled(&self, group: &ServiceGroup) -> bool {
52        self.enabled_groups.contains(group)
53    }
54
55    pub fn str_param(description: &str, _required: bool, default_: Option<&str>) -> Value {
56        let mut schema = serde_json::json!({
57            "type": "string",
58            "description": description,
59        });
60        if let Some(d) = default_ {
61            schema["default"] = Value::String(d.to_string());
62        }
63        schema
64    }
65
66    pub fn num_param(description: &str, _required: bool) -> Value {
67        serde_json::json!({
68            "type": "number",
69            "description": description,
70        })
71    }
72
73    pub fn bool_param(description: &str) -> Value {
74        serde_json::json!({
75            "type": "boolean",
76            "description": description,
77        })
78    }
79
80    /// Helper to get account info from Indodax API.
81    pub async fn get_account_info(&self) -> Result<Value, IndodaxError> {
82        self.client
83            .private_post_v1::<Value>("getInfo", &std::collections::HashMap::new())
84            .await
85    }
86
87    pub fn tool_def(name: &str, description: &str, properties: Value, required: Vec<&str>) -> Tool {
88        let mut schema = Map::new();
89        schema.insert("type".to_string(), Value::String("object".to_string()));
90
91        if let Value::Object(props) = properties {
92            if !props.is_empty() {
93                schema.insert("properties".to_string(), Value::Object(props));
94            }
95        }
96
97        if !required.is_empty() {
98            let req_values: Vec<Value> = required
99                .iter()
100                .map(|s| Value::String(s.to_string()))
101                .collect();
102            schema.insert("required".to_string(), Value::Array(req_values));
103        }
104
105        Tool::new(name.to_string(), description.to_string(), Arc::new(schema))
106    }
107
108    pub fn get_str(args: &Map<String, Value>, name: &str) -> Option<String> {
109        args.get(name)
110            .and_then(|v| v.as_str())
111            .map(|s| s.to_string())
112    }
113
114    pub fn get_num(args: &Map<String, Value>, name: &str) -> Option<f64> {
115        args.get(name).and_then(|v| {
116            v.as_f64()
117                .or_else(|| v.as_str().and_then(|s| s.parse::<f64>().ok()))
118        })
119    }
120
121    pub fn get_bool(args: &Map<String, Value>, name: &str) -> bool {
122        Self::get_opt_bool(args, name).unwrap_or(false)
123    }
124
125    pub fn get_opt_bool(args: &Map<String, Value>, name: &str) -> Option<bool> {
126        args.get(name).and_then(|v| v.as_bool())
127    }
128
129    pub fn ok_result(text: String) -> CallToolResult {
130        CallToolResult::success(vec![Content::text(text)])
131    }
132
133    pub fn error_result(text: String) -> CallToolResult {
134        let envelope = serde_json::json!({
135            "error": true,
136            "message": text,
137            "error_type": "mcp_error",
138        });
139        CallToolResult::error(vec![Content::text(
140            serde_json::to_string_pretty(&envelope).unwrap_or(text),
141        )])
142    }
143
144    pub fn validation_error_result(text: String) -> CallToolResult {
145        let envelope = serde_json::json!({
146            "error": true,
147            "message": text,
148            "error_type": "validation_error",
149        });
150        CallToolResult::error(vec![Content::text(
151            serde_json::to_string_pretty(&envelope).unwrap_or(text),
152        )])
153    }
154
155    pub fn error_from_indodax(err: &IndodaxError) -> CallToolResult {
156        let envelope = serde_json::json!({
157            "error": true,
158            "message": err.to_string(),
159            "error_type": err.category(),
160        });
161        CallToolResult::error(vec![Content::text(
162            serde_json::to_string_pretty(&envelope).unwrap_or_else(|_| err.to_string()),
163        )])
164    }
165
166    pub fn json_result(value: Value) -> CallToolResult {
167        let text = serde_json::to_string_pretty(&value).unwrap_or_default();
168        Self::ok_result(text)
169    }
170}
171
172impl rmcp::handler::server::ServerHandler for IndodaxMcp {
173    fn get_info(&self) -> InitializeResult {
174        InitializeResult::new(
175            ServerCapabilities::builder()
176                .enable_tools()
177                .build(),
178        )
179        .with_server_info(Implementation::new(
180            "indodax-cli",
181            env!("CARGO_PKG_VERSION"),
182        ))
183    }
184
185    async fn list_tools(
186        &self,
187        _request: Option<PaginatedRequestParams>,
188        _context: RequestContext<RoleServer>,
189    ) -> Result<ListToolsResult, McpError> {
190        let tools = self.all_tools();
191        Ok(ListToolsResult::with_all_items(tools))
192    }
193
194    async fn call_tool(
195        &self,
196        request: CallToolRequestParams,
197        _context: RequestContext<RoleServer>,
198    ) -> Result<CallToolResult, McpError> {
199        let name = request.name.to_string();
200        let args = request.arguments.unwrap_or_default();
201
202        let result = match name.as_str() {
203            // Market (public, no auth needed)
204            "server_time" => self.handle_server_time().await,
205            "ticker" => {
206                let pair = helpers::normalize_pair(
207                    &Self::get_str(&args, "pair").unwrap_or_else(|| "btc_idr".into())
208                );
209                self.handle_ticker(&pair).await
210            }
211            "ticker_all" => self.handle_ticker_all().await,
212            "pairs" => self.handle_pairs().await,
213            "summaries" => self.handle_summaries().await,
214            "orderbook" => {
215                let pair = helpers::normalize_pair(
216                    &Self::get_str(&args, "pair").unwrap_or_else(|| "btc_idr".into())
217                );
218                self.handle_orderbook(&pair).await
219            }
220            "trades" => {
221                let pair = helpers::normalize_pair(
222                    &Self::get_str(&args, "pair").unwrap_or_else(|| "btc_idr".into())
223                );
224                self.handle_trades(&pair).await
225            }
226            "ohlc" => {
227                let symbol = helpers::normalize_pair(
228                    &Self::get_str(&args, "symbol").unwrap_or_else(|| "btc_idr".into())
229                ).replace('_', "").to_uppercase();
230                let timeframe =
231                    Self::get_str(&args, "timeframe").unwrap_or_else(|| "60".into());
232                let from = Self::get_num(&args, "from");
233                let to = Self::get_num(&args, "to");
234                self.handle_ohlc(&symbol, &timeframe, from, to).await
235            }
236            "price_increments" => self.handle_price_increments().await,
237
238            // Trade
239            "buy_order" => {
240                let acknowledged = Self::get_bool(&args, "acknowledged");
241                if let Err(msg) =
242                    self.safety.check_operation(&ServiceGroup::Trade, acknowledged)
243                {
244                    return Ok(Self::error_result(msg));
245                }
246                let pair = helpers::normalize_pair(
247                    &Self::get_str(&args, "pair").unwrap_or_default()
248                );
249                let idr = Self::get_num(&args, "idr").unwrap_or(0.0);
250                let price = Self::get_num(&args, "price");
251                self.handle_buy_order(&pair, idr, price).await
252            }
253            "sell_order" => {
254                let acknowledged = Self::get_bool(&args, "acknowledged");
255                if let Err(msg) =
256                    self.safety.check_operation(&ServiceGroup::Trade, acknowledged)
257                {
258                    return Ok(Self::error_result(msg));
259                }
260                let pair = helpers::normalize_pair(
261                    &Self::get_str(&args, "pair").unwrap_or_default()
262                );
263                let price = Self::get_num(&args, "price");
264                let amount = match Self::get_num(&args, "amount") {
265                    Some(v) if v > 0.0 => v,
266                    Some(v) => return Ok(Self::validation_error_result(format!("Amount must be positive, got {}", v))),
267                    None => return Ok(Self::validation_error_result("Missing required parameter: amount".into())),
268                };
269                let order_type = Self::get_str(&args, "order_type").unwrap_or_else(|| "limit".into());
270                self.handle_sell_order(&pair, price, amount, &order_type).await
271            }
272
273            // Account
274            "account_info" => self.handle_account_info().await,
275            "balance" => self.handle_balance().await,
276            "open_orders" => {
277                let pair = Self::get_str(&args, "pair").map(|p| helpers::normalize_pair(&p));
278                self.handle_open_orders(pair.as_deref()).await
279            }
280            "order_history" => {
281                let symbol = helpers::normalize_pair(
282                    &Self::get_str(&args, "symbol").unwrap_or_else(|| "btc_idr".into())
283                );
284                let limit = Self::get_num(&args, "limit");
285                self.handle_order_history(&symbol, limit).await
286            }
287            "trade_history" => {
288                let symbol = helpers::normalize_pair(
289                    &Self::get_str(&args, "symbol").unwrap_or_else(|| "btc_idr".into())
290                );
291                let limit = Self::get_num(&args, "limit");
292                self.handle_trade_history(&symbol, limit).await
293            }
294            "get_order" => {
295                let order_id = match Self::get_num(&args, "order_id") {
296                    Some(v) => {
297                        if v.fract() != 0.0 {
298                            return Ok(Self::validation_error_result(format!("order_id must be a whole number, got {}", v)));
299                        }
300                        v
301                    }
302                    None => return Ok(Self::validation_error_result("Missing required parameter: order_id".into())),
303                };
304                let pair = match Self::get_str(&args, "pair") {
305                    Some(v) => helpers::normalize_pair(&v),
306                    None => return Ok(Self::validation_error_result("Missing required parameter: pair".into())),
307                };
308                self.handle_get_order(order_id, &pair).await
309            }
310            "trans_history" => self.handle_trans_history().await,
311            "cancel_order" => {
312                let acknowledged = Self::get_bool(&args, "acknowledged");
313                if let Err(msg) =
314                    self.safety.check_operation(&ServiceGroup::Trade, acknowledged)
315                {
316                    return Ok(Self::error_result(msg));
317                }
318                let order_id = match Self::get_num(&args, "order_id") {
319                    Some(v) => {
320                        if v.fract() != 0.0 {
321                            return Ok(Self::validation_error_result(format!("order_id must be a whole number, got {}", v)));
322                        }
323                        v
324                    }
325                    None => return Ok(Self::validation_error_result("Missing required parameter: order_id".into())),
326                };
327                let pair = match Self::get_str(&args, "pair") {
328                    Some(v) => helpers::normalize_pair(&v),
329                    None => return Ok(Self::validation_error_result("Missing required parameter: pair".into())),
330                };
331                let order_type = Self::get_str(&args, "order_type").unwrap_or_default();
332                self.handle_cancel_order(order_id, &pair, &order_type).await
333            }
334            "cancel_all_orders" => {
335                let acknowledged = Self::get_bool(&args, "acknowledged");
336                if let Err(msg) =
337                    self.safety.check_operation(&ServiceGroup::Trade, acknowledged)
338                {
339                    return Ok(Self::error_result(msg));
340                }
341                let pair = Self::get_str(&args, "pair")
342                    .map(|p| helpers::normalize_pair(&p));
343                self.handle_cancel_all_orders(pair.as_deref()).await
344            }
345
346            // Funding
347            "withdraw_fee" => {
348                let currency = Self::get_str(&args, "currency").unwrap_or_default();
349                let network = Self::get_str(&args, "network");
350                self.handle_withdraw_fee(&currency, network.as_deref())
351                    .await
352            }
353            "withdraw" => {
354                let acknowledged = Self::get_bool(&args, "acknowledged");
355                if let Err(msg) =
356                    self.safety.check_operation(&ServiceGroup::Funding, acknowledged)
357                {
358                    return Ok(Self::error_result(msg));
359                }
360                let currency = Self::get_str(&args, "currency").unwrap_or_default();
361                let amount = Self::get_num(&args, "amount").unwrap_or(0.0);
362                let address = Self::get_str(&args, "address").unwrap_or_default();
363                let to_username = Self::get_bool(&args, "to_username");
364                let memo = Self::get_str(&args, "memo");
365                let network = Self::get_str(&args, "network");
366                let callback_url = Self::get_str(&args, "callback_url");
367                self.handle_withdraw(
368                    &currency,
369                    amount,
370                    &address,
371                    to_username,
372                    memo.as_deref(),
373                    network.as_deref(),
374                    callback_url.as_deref(),
375                )
376                .await
377            }
378
379            // Paper
380            "paper_init" => {
381                let idr = Self::get_num(&args, "idr");
382                let btc = Self::get_num(&args, "btc");
383                self.handle_paper_init(idr, btc).await
384            }
385            "paper_reset" => self.handle_paper_reset().await,
386            "paper_balance" => self.handle_paper_balance().await,
387            "paper_buy" | "paper_sell" => {
388                let pair = helpers::normalize_pair(
389                    &Self::get_str(&args, "pair")
390                        .unwrap_or_else(|| "btc_idr".into())
391                );
392                let price = Self::get_num(&args, "price");
393                let amount = Self::get_num(&args, "amount");
394                let idr = Self::get_num(&args, "idr");
395                let side = if name == "paper_buy" {
396                    "buy"
397                } else {
398                    "sell"
399                };
400                self.handle_paper_trade(side, &pair, price, amount, idr).await
401            }
402            "paper_orders" => self.handle_paper_orders().await,
403            "paper_cancel" => {
404                let order_id = match Self::get_num(&args, "order_id") {
405                    Some(v) => {
406                        if v.fract() != 0.0 {
407                            return Ok(Self::validation_error_result(format!("order_id must be a whole number, got {}", v)));
408                        }
409                        v as u64
410                    }
411                    None => return Ok(Self::validation_error_result("Missing required parameter: order_id".into())),
412                };
413                self.handle_paper_cancel(order_id).await
414            }
415            "paper_cancel_all" => self.handle_paper_cancel_all().await,
416            "paper_history" => self.handle_paper_history().await,
417            "paper_status" => self.handle_paper_status().await,
418            "paper_fill" => {
419                let order_id = Self::get_num(&args, "order_id");
420                let price = Self::get_num(&args, "price");
421                let all = Self::get_bool(&args, "all");
422                self.handle_paper_fill(order_id, price, all).await
423            }
424            "paper_check_fills" => {
425                let prices = Self::get_str(&args, "prices");
426                let fetch = Self::get_bool(&args, "fetch");
427                self.handle_paper_check_fills(prices.as_deref(), fetch).await
428            }
429
430            // Auth
431            "auth_show" => self.handle_auth_show().await,
432            "auth_test" => self.handle_auth_test().await,
433
434            _ => Self::error_result(format!("Unknown tool: {}", name)),
435        };
436
437        Ok(result)
438    }
439}
440
441fn all_tools(mcp: &IndodaxMcp) -> Vec<Tool> {
442    let mut tools: Vec<Tool> = Vec::new();
443
444    if mcp.is_group_enabled(&ServiceGroup::Market) {
445        tools.extend(market::market_tools());
446    }
447    if mcp.is_group_enabled(&ServiceGroup::Account) {
448        tools.extend(account::account_tools());
449    }
450    if mcp.is_group_enabled(&ServiceGroup::Trade) {
451        if let Err(msg) = mcp.safety.check_group(&ServiceGroup::Trade) {
452            eprintln!("[MCP] Warning: {}", msg);
453        }
454        tools.extend(trade::trade_tools());
455    }
456    if mcp.is_group_enabled(&ServiceGroup::Funding) {
457        if let Err(msg) = mcp.safety.check_group(&ServiceGroup::Funding) {
458            eprintln!("[MCP] Warning: {}", msg);
459        }
460        tools.extend(funding::funding_tools());
461    }
462    if mcp.is_group_enabled(&ServiceGroup::Paper) {
463        tools.extend(paper::paper_tools());
464    }
465    if mcp.is_group_enabled(&ServiceGroup::Auth) {
466        tools.extend(auth::auth_tools());
467    }
468
469    tools
470}
471
472impl IndodaxMcp {
473    pub fn all_tools(&self) -> Vec<Tool> {
474        all_tools(self)
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::IndodaxClient;
    use crate::config::IndodaxConfig;
    use serde_json::json;

    /// Server with only the Market and Paper groups enabled and a default safety policy.
    fn test_mcp() -> IndodaxMcp {
        let client = IndodaxClient::new(None).expect("unauthenticated client should build");
        let config = IndodaxConfig::default();
        let safety = SafetyConfig::new(false);
        let groups = vec![ServiceGroup::Market, ServiceGroup::Paper];
        IndodaxMcp::new(client, config, safety, groups)
    }

    /// Text of the first content item of a tool result ("" when there is none).
    fn first_text(result: &CallToolResult) -> &str {
        result
            .content
            .first()
            .and_then(|c| c.as_text())
            .map(|t| t.text.as_str())
            .unwrap_or("")
    }

    /// Parses the JSON error envelope produced by the error helpers.
    fn envelope(result: &CallToolResult) -> Value {
        serde_json::from_str(first_text(result)).expect("error text should be a JSON envelope")
    }

    #[test]
    fn test_get_str() {
        let mut args = Map::new();
        args.insert("name".into(), json!("test_value"));
        assert_eq!(IndodaxMcp::get_str(&args, "name"), Some("test_value".into()));
        assert_eq!(IndodaxMcp::get_str(&args, "missing"), None);
    }

    #[test]
    fn test_get_num_from_number() {
        let mut args = Map::new();
        args.insert("price".into(), json!(100.5));
        assert_eq!(IndodaxMcp::get_num(&args, "price"), Some(100.5));
    }

    #[test]
    fn test_get_num_from_string() {
        let mut args = Map::new();
        args.insert("amount".into(), json!("50.25"));
        assert_eq!(IndodaxMcp::get_num(&args, "amount"), Some(50.25));
    }

    #[test]
    fn test_get_num_missing() {
        let args = Map::new();
        assert_eq!(IndodaxMcp::get_num(&args, "missing"), None);
    }

    #[test]
    fn test_get_num_rejects_non_numeric() {
        let mut args = Map::new();
        args.insert("word".into(), json!("abc"));
        args.insert("flag".into(), json!(true));
        assert_eq!(IndodaxMcp::get_num(&args, "word"), None);
        assert_eq!(IndodaxMcp::get_num(&args, "flag"), None);
    }

    #[test]
    fn test_get_bool_true() {
        let mut args = Map::new();
        args.insert("acknowledged".into(), json!(true));
        assert!(IndodaxMcp::get_bool(&args, "acknowledged"));
    }

    #[test]
    fn test_get_bool_false() {
        let mut args = Map::new();
        args.insert("flag".into(), json!(false));
        assert!(!IndodaxMcp::get_bool(&args, "flag"));
    }

    #[test]
    fn test_get_bool_missing_defaults_false() {
        let args = Map::new();
        assert!(!IndodaxMcp::get_bool(&args, "missing"));
    }

    #[test]
    fn test_get_opt_bool() {
        let mut args = Map::new();
        args.insert("true_val".into(), json!(true));
        args.insert("false_val".into(), json!(false));

        assert_eq!(IndodaxMcp::get_opt_bool(&args, "true_val"), Some(true));
        assert_eq!(IndodaxMcp::get_opt_bool(&args, "false_val"), Some(false));
        assert_eq!(IndodaxMcp::get_opt_bool(&args, "missing"), None);
    }

    #[test]
    fn test_tool_def_creates_tool() {
        let properties = json!({
            "pair": {
                "type": "string",
                "description": "Trading pair"
            }
        });
        let tool = IndodaxMcp::tool_def("test_tool", "A test tool", properties, vec!["pair"]);
        assert_eq!(tool.name.to_string(), "test_tool");
        assert_eq!(tool.description.as_deref(), Some("A test tool"));

        let schema = &tool.input_schema;
        assert_eq!(schema.get("type"), Some(&json!("object")));
        assert_eq!(schema.get("required"), Some(&json!(["pair"])));
        assert!(schema
            .get("properties")
            .and_then(|props| props.get("pair"))
            .is_some());
    }

    #[test]
    fn test_tool_def_no_required_params() {
        let tool = IndodaxMcp::tool_def("empty_tool", "No params", json!({}), vec![]);
        assert_eq!(tool.name.to_string(), "empty_tool");

        let schema = &tool.input_schema;
        assert_eq!(schema.get("type"), Some(&json!("object")));
        assert!(schema.get("properties").is_none());
        assert!(schema.get("required").is_none());
    }

    #[test]
    fn test_str_param() {
        let param = IndodaxMcp::str_param("A test string", false, Some("default"));
        assert_eq!(param["type"], "string");
        assert_eq!(param["default"], "default");
    }

    #[test]
    fn test_str_param_without_default() {
        let param = IndodaxMcp::str_param("A test string", true, None);
        assert_eq!(param["description"], "A test string");
        assert!(param.get("default").is_none());
    }

    #[test]
    fn test_num_param() {
        let param = IndodaxMcp::num_param("A test number", true);
        assert_eq!(param["type"], "number");
    }

    #[test]
    fn test_bool_param() {
        let param = IndodaxMcp::bool_param("A test boolean");
        assert_eq!(param["type"], "boolean");
    }

    #[test]
    fn test_mcp_is_group_enabled() {
        let mcp = test_mcp();
        assert!(mcp.is_group_enabled(&ServiceGroup::Market));
        assert!(mcp.is_group_enabled(&ServiceGroup::Paper));
        assert!(!mcp.is_group_enabled(&ServiceGroup::Trade));
        assert!(!mcp.is_group_enabled(&ServiceGroup::Account));
        assert!(!mcp.is_group_enabled(&ServiceGroup::Funding));
        assert!(!mcp.is_group_enabled(&ServiceGroup::Auth));
    }

    #[test]
    fn test_ok_result() {
        let result = IndodaxMcp::ok_result("success".into());
        assert_eq!(result.is_error, Some(false));
        assert_eq!(first_text(&result), "success");
    }

    #[test]
    fn test_error_result_contains_error() {
        let result = IndodaxMcp::error_result("something failed".into());
        assert_eq!(result.is_error, Some(true));

        let body = envelope(&result);
        assert_eq!(body["error"], true);
        assert_eq!(body["message"], "something failed");
        assert_eq!(body["error_type"], "mcp_error");
    }

    #[test]
    fn test_validation_error_result_contains_validation_type() {
        let result = IndodaxMcp::validation_error_result("bad input".into());
        assert_eq!(result.is_error, Some(true));

        let body = envelope(&result);
        assert_eq!(body["error"], true);
        assert_eq!(body["message"], "bad input");
        assert_eq!(body["error_type"], "validation_error");
    }

    #[test]
    fn test_json_result() {
        let value = json!({"key": "value", "num": 42});
        let result = IndodaxMcp::json_result(value.clone());
        assert_eq!(result.is_error, Some(false));

        let round_trip: Value =
            serde_json::from_str(first_text(&result)).expect("json_result emits JSON text");
        assert_eq!(round_trip, value);
    }

    #[test]
    fn test_all_tools_respects_groups() {
        let mcp = test_mcp();
        let names: Vec<String> = mcp.all_tools().iter().map(|t| t.name.to_string()).collect();
        let has = |tool: &str| names.iter().any(|n| n == tool);

        assert!(has("server_time"));
        assert!(has("ticker"));
        assert!(has("paper_init"));
        assert!(has("paper_balance"));
        assert!(!has("buy_order"));
        assert!(!has("sell_order"));
        assert!(!has("account_info"));
    }
}
639        assert!(!names.contains(&"buy_order".to_string()));
640        assert!(!names.contains(&"sell_order".to_string()));
641        assert!(!names.contains(&"account_info".to_string()));
642    }
643}