Skip to main content

hyper_agent_ai/tools/
mod.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use hyper_exchange::ExchangeClient;
6use hyper_risk::risk::{record_risk_alert, AccountState, OrderRequest, RiskGuard};
7use motosan_chat_tool::{Tool, ToolContext, ToolDef, ToolResult, Value};
8
9use hyper_agent_core::position_manager::PositionManager;
10
11pub struct GetPositionsTool {
12    position_manager: Arc<PositionManager>,
13}
14
15impl GetPositionsTool {
16    pub fn new(position_manager: Arc<PositionManager>) -> Self {
17        Self { position_manager }
18    }
19}
20
21impl Tool for GetPositionsTool {
22    fn def(&self) -> ToolDef {
23        ToolDef {
24            name: "get_positions".to_string(),
25            description: "Get current open positions and account equity".to_string(),
26            input_schema: serde_json::json!({"type": "object", "properties": {}}),
27        }
28    }
29
30    fn call(
31        &self,
32        _args: Value,
33        _ctx: &ToolContext,
34    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
35        let pm = self.position_manager.clone();
36        Box::pin(async move {
37            match pm.list_open().await {
38                Ok(positions) => {
39                    let count = positions.len();
40                    let pos_json: Vec<serde_json::Value> = positions
41                        .iter()
42                        .map(|p| {
43                            serde_json::json!({
44                                "id": p.id,
45                                "market": p.market,
46                                "side": p.side,
47                                "size": p.size,
48                                "entry_price": p.entry_price,
49                                "current_price": p.current_price,
50                                "pnl": p.pnl,
51                                "mode": p.mode,
52                            })
53                        })
54                        .collect();
55                    ToolResult::json(serde_json::json!({"positions": pos_json, "count": count}))
56                }
57                Err(e) => ToolResult::error(format!("Failed to list positions: {}", e)),
58            }
59        })
60    }
61}
62
63pub struct PlaceOrderTool {
64    risk_guard: RiskGuard,
65    account_state: AccountState,
66}
67
68impl PlaceOrderTool {
69    pub fn new(risk_guard: RiskGuard, account_state: AccountState) -> Self {
70        Self {
71            risk_guard,
72            account_state,
73        }
74    }
75}
76
77impl Tool for PlaceOrderTool {
78    fn def(&self) -> ToolDef {
79        ToolDef {
80            name: "place_order".to_string(),
81            description: "Submit a place order request".to_string(),
82            input_schema: serde_json::json!({
83                "type": "object",
84                "properties": {
85                    "asset": {"type": "number"},
86                    "is_buy": {"type": "boolean"},
87                    "price": {"type": ["string", "number"]},
88                    "size": {"type": ["string", "number"]}
89                },
90                "required": ["asset", "is_buy", "price", "size"]
91            }),
92        }
93    }
94
95    fn call(
96        &self,
97        args: Value,
98        _ctx: &ToolContext,
99    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
100        Box::pin(async move {
101            let order_req = order_request_from_args(&args);
102            match self.risk_guard.check_order(&order_req, &self.account_state) {
103                Ok(()) => ToolResult::json(serde_json::json!({
104                    "status": "accepted",
105                    "request": args,
106                })),
107                Err(violation) => {
108                    record_risk_alert(&violation);
109                    ToolResult::error(violation.to_string())
110                }
111            }
112        })
113    }
114}
115
116pub struct CancelOrderTool {
117    risk_guard: RiskGuard,
118    account_state: AccountState,
119}
120
121impl CancelOrderTool {
122    pub fn new(risk_guard: RiskGuard, account_state: AccountState) -> Self {
123        Self {
124            risk_guard,
125            account_state,
126        }
127    }
128}
129
130impl Tool for CancelOrderTool {
131    fn def(&self) -> ToolDef {
132        ToolDef {
133            name: "cancel_order".to_string(),
134            description: "Cancel an existing order by id".to_string(),
135            input_schema: serde_json::json!({
136                "type": "object",
137                "properties": {"order_id": {"type": ["string", "number"]}},
138                "required": ["order_id"]
139            }),
140        }
141    }
142
143    fn call(
144        &self,
145        args: Value,
146        _ctx: &ToolContext,
147    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
148        Box::pin(async move {
149            let synthetic_request = OrderRequest {
150                symbol: "unknown".to_string(),
151                side: "sell".to_string(),
152                size: 0.0,
153                price: 0.0,
154            };
155            match self
156                .risk_guard
157                .check_order(&synthetic_request, &self.account_state)
158            {
159                Ok(()) => ToolResult::json(serde_json::json!({
160                    "status": "accepted",
161                    "request": args,
162                })),
163                Err(violation) => {
164                    record_risk_alert(&violation);
165                    ToolResult::error(violation.to_string())
166                }
167            }
168        })
169    }
170}
171
172pub struct DoNothingTool;
173
174impl Tool for DoNothingTool {
175    fn def(&self) -> ToolDef {
176        ToolDef {
177            name: "do_nothing".to_string(),
178            description: "Return hold/no-action decision".to_string(),
179            input_schema: serde_json::json!({
180                "type": "object",
181                "properties": {"reason": {"type": "string"}}
182            }),
183        }
184    }
185
186    fn call(
187        &self,
188        args: Value,
189        _ctx: &ToolContext,
190    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
191        Box::pin(async move {
192            ToolResult::json(serde_json::json!({
193                "status": "noop",
194                "reason": args.get("reason").and_then(|v| v.as_str()).unwrap_or("No action")
195            }))
196        })
197    }
198}
199
200pub struct SetStopLossTool;
201
202impl Tool for SetStopLossTool {
203    fn def(&self) -> ToolDef {
204        ToolDef {
205            name: "set_stop_loss".to_string(),
206            description: "Set a stop-loss trigger on an existing position. Triggers a market close when the price reaches the specified trigger price.".to_string(),
207            input_schema: serde_json::json!({
208                "type": "object",
209                "properties": {
210                    "symbol": {"type": "string", "description": "Trading pair symbol, e.g. BTC-PERP"},
211                    "trigger_price": {"type": "number", "description": "Price at which the stop-loss triggers"}
212                },
213                "required": ["symbol", "trigger_price"]
214            }),
215        }
216    }
217
218    fn call(
219        &self,
220        args: Value,
221        _ctx: &ToolContext,
222    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
223        Box::pin(async move {
224            let symbol = args
225                .get("symbol")
226                .and_then(|v| v.as_str())
227                .unwrap_or("unknown");
228            let trigger_price = args
229                .get("trigger_price")
230                .and_then(|v| v.as_f64())
231                .unwrap_or(0.0);
232            ToolResult::json(serde_json::json!({
233                "status": "accepted",
234                "trigger_type": "stop_loss",
235                "symbol": symbol,
236                "trigger_price": trigger_price
237            }))
238        })
239    }
240}
241
242pub struct SetTakeProfitTool;
243
244impl Tool for SetTakeProfitTool {
245    fn def(&self) -> ToolDef {
246        ToolDef {
247            name: "set_take_profit".to_string(),
248            description: "Set a take-profit trigger on an existing position. Triggers a market close when the price reaches the specified trigger price.".to_string(),
249            input_schema: serde_json::json!({
250                "type": "object",
251                "properties": {
252                    "symbol": {"type": "string", "description": "Trading pair symbol, e.g. BTC-PERP"},
253                    "trigger_price": {"type": "number", "description": "Price at which the take-profit triggers"}
254                },
255                "required": ["symbol", "trigger_price"]
256            }),
257        }
258    }
259
260    fn call(
261        &self,
262        args: Value,
263        _ctx: &ToolContext,
264    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
265        Box::pin(async move {
266            let symbol = args
267                .get("symbol")
268                .and_then(|v| v.as_str())
269                .unwrap_or("unknown");
270            let trigger_price = args
271                .get("trigger_price")
272                .and_then(|v| v.as_f64())
273                .unwrap_or(0.0);
274            ToolResult::json(serde_json::json!({
275                "status": "accepted",
276                "trigger_type": "take_profit",
277                "symbol": symbol,
278                "trigger_price": trigger_price
279            }))
280        })
281    }
282}
283
/// Tool that lists open orders, optionally filtered by symbol.
#[derive(Debug, Clone)]
pub struct GetOpenOrdersTool {
    /// Network selector forwarded to `fetch_open_orders`.
    is_mainnet: bool,
}

impl GetOpenOrdersTool {
    /// Creates the tool; `is_mainnet` is stored and passed on with every query.
    pub fn new(is_mainnet: bool) -> Self {
        Self { is_mainnet }
    }
}
293
294impl Tool for GetOpenOrdersTool {
295    fn def(&self) -> ToolDef {
296        ToolDef {
297            name: "get_open_orders".to_string(),
298            description: "List all open (pending/resting) orders. Optionally filter by a specific trading pair symbol. Returns order details including price, size, side, and order ID.".to_string(),
299            input_schema: serde_json::json!({
300                "type": "object",
301                "properties": {
302                    "symbol": {"type": "string", "description": "Optional trading pair symbol to filter by, e.g. BTC-PERP"}
303                }
304            }),
305        }
306    }
307
308    fn call(
309        &self,
310        args: Value,
311        _ctx: &ToolContext,
312    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
313        Box::pin(async move {
314            let symbol_filter = args.get("symbol").and_then(|v| v.as_str());
315            match fetch_open_orders(symbol_filter, self.is_mainnet).await {
316                Ok(info) => ToolResult::json(info),
317                Err(e) => ToolResult::error(format!("Failed to fetch open orders: {}", e)),
318            }
319        })
320    }
321}
322
323/// Fetch open orders from Hyperliquid public API.
324///
325/// Requires a user address. Since this tool runs inside the agent loop,
326/// we return a placeholder result. The real execution happens in `execute_tool`.
327pub async fn fetch_open_orders(
328    symbol_filter: Option<&str>,
329    _is_mainnet: bool,
330) -> Result<serde_json::Value, String> {
331    // This is a stub — actual exchange queries require the user address,
332    // which is resolved at the agent_loop / execute_tool level.
333    // The tool definition is used by the LLM; execution is handled by execute_tool.
334    Ok(serde_json::json!({
335        "orders": [],
336        "symbol_filter": symbol_filter,
337        "message": "Open orders query accepted"
338    }))
339}
340
341pub struct CloseAllPositionsTool;
342
343impl Tool for CloseAllPositionsTool {
344    fn def(&self) -> ToolDef {
345        ToolDef {
346            name: "close_all_positions".to_string(),
347            description: "Emergency: close ALL open positions immediately with market orders. Use when circuit breaker triggers or emergency situations arise.".to_string(),
348            input_schema: serde_json::json!({
349                "type": "object",
350                "properties": {
351                    "reason": {"type": "string", "description": "Reason for the emergency close"}
352                },
353                "required": ["reason"]
354            }),
355        }
356    }
357
358    fn call(
359        &self,
360        args: Value,
361        _ctx: &ToolContext,
362    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
363        Box::pin(async move {
364            let reason = args
365                .get("reason")
366                .and_then(|v| v.as_str())
367                .unwrap_or("emergency");
368            ToolResult::json(serde_json::json!({
369                "status": "accepted",
370                "action": "close_all_positions",
371                "reason": reason
372            }))
373        })
374    }
375}
376
/// Tool that looks up the current funding rate for a symbol.
#[derive(Debug, Clone)]
pub struct GetFundingRateTool {
    /// Network selector forwarded to `fetch_funding_rate`.
    is_mainnet: bool,
}

impl GetFundingRateTool {
    /// Creates the tool; `is_mainnet` is stored and passed on with every query.
    pub fn new(is_mainnet: bool) -> Self {
        Self { is_mainnet }
    }
}
386
387impl Tool for GetFundingRateTool {
388    fn def(&self) -> ToolDef {
389        ToolDef {
390            name: "get_funding_rate".to_string(),
391            description: "Get current funding rate for a perpetual contract symbol. Use to assess overnight holding cost.".to_string(),
392            input_schema: serde_json::json!({
393                "type": "object",
394                "properties": {
395                    "symbol": {"type": "string", "description": "Trading pair symbol, e.g. BTC-PERP"}
396                },
397                "required": ["symbol"]
398            }),
399        }
400    }
401
402    fn call(
403        &self,
404        args: Value,
405        _ctx: &ToolContext,
406    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
407        Box::pin(async move {
408            let symbol = match args.get("symbol").and_then(|v| v.as_str()) {
409                Some(s) => s,
410                None => return ToolResult::error("Missing required parameter: symbol"),
411            };
412
413            match fetch_funding_rate(symbol, self.is_mainnet).await {
414                Ok(info) => ToolResult::json(info),
415                Err(e) => ToolResult::error(format!("Failed to fetch funding rate: {}", e)),
416            }
417        })
418    }
419}
420
421/// Fetch the current funding rate for a symbol from Hyperliquid public API.
422///
423/// Returns a JSON value with funding rate details.
424pub async fn fetch_funding_rate(
425    symbol: &str,
426    is_mainnet: bool,
427) -> Result<serde_json::Value, String> {
428    let client = ExchangeClient::new(is_mainnet);
429
430    // Hyperliquid uses the metaAndAssetCtxs endpoint to get funding rates
431    let request = serde_json::json!({
432        "type": "metaAndAssetCtxs"
433    });
434
435    let response = client
436        .post_info(request)
437        .await
438        .map_err(|e| format!("API error: {}", e))?;
439
440    // Parse the response: [metaObj, [assetCtx, ...]]
441    let arr = response
442        .as_array()
443        .ok_or_else(|| "Invalid response format".to_string())?;
444
445    if arr.len() < 2 {
446        return Err("Incomplete response from API".to_string());
447    }
448
449    let meta_obj = &arr[0];
450    let asset_ctxs = arr[1]
451        .as_array()
452        .ok_or_else(|| "Invalid asset contexts format".to_string())?;
453
454    let universe = meta_obj
455        .get("universe")
456        .and_then(|u| u.as_array())
457        .ok_or_else(|| "Missing universe in meta".to_string())?;
458
459    // Strip suffix to get coin name
460    let coin = symbol
461        .replace("-PERP", "")
462        .replace("-USDC", "")
463        .replace("-USD", "");
464
465    let idx = universe
466        .iter()
467        .position(|item| item.get("name").and_then(|n| n.as_str()) == Some(&coin))
468        .ok_or_else(|| format!("Symbol {} not found", symbol))?;
469
470    if idx >= asset_ctxs.len() {
471        return Err(format!("Asset context not found for {}", symbol));
472    }
473
474    let ctx = &asset_ctxs[idx];
475
476    let funding_rate = ctx
477        .get("funding")
478        .and_then(|v| v.as_str())
479        .and_then(|s| s.parse::<f64>().ok())
480        .unwrap_or(0.0);
481
482    let mark_price = ctx
483        .get("markPx")
484        .and_then(|v| v.as_str())
485        .and_then(|s| s.parse::<f64>().ok())
486        .unwrap_or(0.0);
487
488    let open_interest = ctx
489        .get("openInterest")
490        .and_then(|v| v.as_str())
491        .and_then(|s| s.parse::<f64>().ok())
492        .unwrap_or(0.0);
493
494    // Annualized rate: funding is per 8h, so multiply by 3*365
495    let annualized_rate = funding_rate * 3.0 * 365.0;
496
497    Ok(serde_json::json!({
498        "symbol": symbol,
499        "funding_rate": funding_rate,
500        "funding_rate_pct": format!("{:.4}%", funding_rate * 100.0),
501        "annualized_rate_pct": format!("{:.2}%", annualized_rate * 100.0),
502        "mark_price": mark_price,
503        "open_interest": open_interest,
504        "direction": if funding_rate > 0.0 { "longs_pay_shorts" } else if funding_rate < 0.0 { "shorts_pay_longs" } else { "neutral" }
505    }))
506}
507
/// Tool that returns order-book levels for a symbol.
#[derive(Debug, Clone)]
pub struct GetMarketDataTool {
    /// Network selector forwarded to `fetch_market_data`.
    is_mainnet: bool,
}

impl GetMarketDataTool {
    /// Creates the tool; `is_mainnet` is stored and passed on with every query.
    pub fn new(is_mainnet: bool) -> Self {
        Self { is_mainnet }
    }
}
517
518impl Tool for GetMarketDataTool {
519    fn def(&self) -> ToolDef {
520        ToolDef {
521            name: "get_market_data".to_string(),
522            description: "Get orderbook (bids/asks) for a symbol. Use to check spread and liquidity before placing limit orders.".to_string(),
523            input_schema: serde_json::json!({
524                "type": "object",
525                "properties": {
526                    "symbol": {"type": "string", "description": "Trading pair symbol, e.g. BTC-PERP"},
527                    "depth": {"type": "integer", "description": "Number of orderbook levels to return (default 5)"}
528                },
529                "required": ["symbol"]
530            }),
531        }
532    }
533
534    fn call(
535        &self,
536        args: Value,
537        _ctx: &ToolContext,
538    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
539        Box::pin(async move {
540            let symbol = match args.get("symbol").and_then(|v| v.as_str()) {
541                Some(s) => s,
542                None => return ToolResult::error("Missing required parameter: symbol"),
543            };
544            let depth = args.get("depth").and_then(|v| v.as_u64()).unwrap_or(5) as usize;
545
546            match fetch_market_data(symbol, depth, self.is_mainnet).await {
547                Ok(info) => ToolResult::json(info),
548                Err(e) => ToolResult::error(format!("Failed to fetch market data: {}", e)),
549            }
550        })
551    }
552}
553
554/// Fetch L2 orderbook data for a symbol from Hyperliquid public API.
555///
556/// Returns top bids/asks and the spread. `depth` controls how many levels to return.
557pub async fn fetch_market_data(
558    symbol: &str,
559    depth: usize,
560    is_mainnet: bool,
561) -> Result<serde_json::Value, String> {
562    let client = ExchangeClient::new(is_mainnet);
563
564    // Strip suffix to get coin name for the API
565    let coin = symbol
566        .replace("-PERP", "")
567        .replace("-USDC", "")
568        .replace("-USD", "");
569
570    let request = serde_json::json!({
571        "type": "l2Book",
572        "coin": coin
573    });
574
575    let response = client
576        .post_info(request)
577        .await
578        .map_err(|e| format!("API error: {}", e))?;
579
580    // Parse the L2 book response: {"levels": [[{px, sz, n}, ...], [{px, sz, n}, ...]]}
581    let levels = response
582        .get("levels")
583        .and_then(|l| l.as_array())
584        .ok_or_else(|| "Invalid L2 book response format".to_string())?;
585
586    if levels.len() < 2 {
587        return Err("Incomplete L2 book response".to_string());
588    }
589
590    let raw_bids = levels[0]
591        .as_array()
592        .ok_or_else(|| "Invalid bids format".to_string())?;
593    let raw_asks = levels[1]
594        .as_array()
595        .ok_or_else(|| "Invalid asks format".to_string())?;
596
597    let depth = depth.min(20); // cap at 20 levels max
598
599    let parse_level = |entry: &serde_json::Value| -> Option<serde_json::Value> {
600        let px = entry
601            .get("px")
602            .and_then(|v| v.as_str())?
603            .parse::<f64>()
604            .ok()?;
605        let sz = entry
606            .get("sz")
607            .and_then(|v| v.as_str())?
608            .parse::<f64>()
609            .ok()?;
610        let n = entry.get("n").and_then(|v| v.as_u64()).unwrap_or(0);
611        Some(serde_json::json!({
612            "price": px,
613            "size": sz,
614            "num_orders": n
615        }))
616    };
617
618    let bids: Vec<serde_json::Value> = raw_bids
619        .iter()
620        .take(depth)
621        .filter_map(parse_level)
622        .collect();
623    let asks: Vec<serde_json::Value> = raw_asks
624        .iter()
625        .take(depth)
626        .filter_map(parse_level)
627        .collect();
628
629    // Calculate spread
630    let best_bid = bids
631        .first()
632        .and_then(|b| b.get("price"))
633        .and_then(|v| v.as_f64());
634    let best_ask = asks
635        .first()
636        .and_then(|a| a.get("price"))
637        .and_then(|v| v.as_f64());
638
639    let (spread, spread_pct, mid_price) = match (best_bid, best_ask) {
640        (Some(bid), Some(ask)) => {
641            let s = ask - bid;
642            let mid = (ask + bid) / 2.0;
643            let pct = if mid > 0.0 { s / mid * 100.0 } else { 0.0 };
644            (Some(s), Some(pct), Some(mid))
645        }
646        _ => (None, None, None),
647    };
648
649    Ok(serde_json::json!({
650        "symbol": symbol,
651        "bids": bids,
652        "asks": asks,
653        "best_bid": best_bid,
654        "best_ask": best_ask,
655        "mid_price": mid_price,
656        "spread": spread,
657        "spread_pct": spread_pct.map(|p| format!("{:.4}%", p)),
658        "depth": depth
659    }))
660}
661
/// Tool that returns recent trade fills, optionally filtered and limited.
#[derive(Debug, Clone)]
pub struct GetTradeHistoryTool {
    /// Network selector forwarded to `fetch_trade_history`.
    is_mainnet: bool,
}

impl GetTradeHistoryTool {
    /// Creates the tool; `is_mainnet` is stored and passed on with every query.
    pub fn new(is_mainnet: bool) -> Self {
        Self { is_mainnet }
    }
}
671
672impl Tool for GetTradeHistoryTool {
673    fn def(&self) -> ToolDef {
674        ToolDef {
675            name: "get_trade_history".to_string(),
676            description: "Get recent trade fills (executed trades) for the current user. Optionally filter by symbol and limit the number of results.".to_string(),
677            input_schema: serde_json::json!({
678                "type": "object",
679                "properties": {
680                    "symbol": {"type": "string", "description": "Optional trading pair symbol to filter by, e.g. BTC-PERP"},
681                    "limit": {"type": "integer", "description": "Maximum number of fills to return (default: all)"}
682                }
683            }),
684        }
685    }
686
687    fn call(
688        &self,
689        args: Value,
690        _ctx: &ToolContext,
691    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
692        Box::pin(async move {
693            let symbol_filter = args.get("symbol").and_then(|v| v.as_str());
694            let limit = args
695                .get("limit")
696                .and_then(|v| v.as_u64())
697                .map(|l| l as usize);
698            match fetch_trade_history(symbol_filter, limit, self.is_mainnet).await {
699                Ok(info) => ToolResult::json(info),
700                Err(e) => ToolResult::error(format!("Failed to fetch trade history: {}", e)),
701            }
702        })
703    }
704}
705
706/// Fetch trade history (user fills) from Hyperliquid public API.
707///
708/// This is a stub — actual exchange queries require the user address,
709/// which is resolved at the agent_loop / execute_tool level.
710/// The tool definition is used by the LLM; execution is handled by execute_tool.
711pub async fn fetch_trade_history(
712    symbol_filter: Option<&str>,
713    limit: Option<usize>,
714    _is_mainnet: bool,
715) -> Result<serde_json::Value, String> {
716    Ok(serde_json::json!({
717        "fills": [],
718        "symbol_filter": symbol_filter,
719        "limit": limit,
720        "message": "Trade history query accepted"
721    }))
722}
723
724fn order_request_from_args(args: &Value) -> OrderRequest {
725    let symbol = args
726        .get("asset")
727        .and_then(|v| v.as_u64())
728        .map(|a| a.to_string())
729        .or_else(|| {
730            args.get("symbol")
731                .and_then(|v| v.as_str())
732                .map(|s| s.to_string())
733        })
734        .unwrap_or_else(|| "unknown".to_string());
735
736    let side = args
737        .get("is_buy")
738        .and_then(|v| v.as_bool())
739        .map(|b| if b { "buy" } else { "sell" }.to_string())
740        .or_else(|| {
741            args.get("side")
742                .and_then(|v| v.as_str())
743                .map(ToString::to_string)
744        })
745        .unwrap_or_else(|| "buy".to_string());
746
747    let size = args
748        .get("size")
749        .and_then(|v| {
750            v.as_str()
751                .and_then(|s| s.parse().ok())
752                .or_else(|| v.as_f64())
753        })
754        .unwrap_or(0.0);
755
756    let price = args
757        .get("price")
758        .and_then(|v| {
759            v.as_str()
760                .and_then(|s| s.parse().ok())
761                .or_else(|| v.as_f64())
762        })
763        .unwrap_or(0.0);
764
765    OrderRequest {
766        symbol,
767        side,
768        size,
769        price,
770    }
771}
772
773#[cfg(test)]
774mod tests {
775    use super::*;
776    use std::collections::VecDeque;
777    use std::sync::{Arc, Mutex};
778
779    use async_trait::async_trait;
780    use futures_util::stream;
781    use hyper_agent_core::position_manager::PositionManager;
782    use motosan_chat_agent::{AgentError, AgentLoop, LlmClient, LlmResponse, LlmStream, Message};
783    use motosan_chat_core::{Channel, ChatType, IncomingEvent, IncomingEventKind, Thread};
784    use motosan_chat_tool::ToolDef;
785    use tokio::sync::mpsc;
786
787    fn test_position_manager() -> Arc<PositionManager> {
788        Arc::new(PositionManager::new(":memory:").expect("in-memory position manager"))
789    }
790
791    #[derive(Clone)]
792    struct ScriptedLlm {
793        responses: Arc<Mutex<VecDeque<LlmResponse>>>,
794    }
795
796    #[async_trait]
797    impl LlmClient for ScriptedLlm {
798        async fn chat(
799            &self,
800            _messages: Vec<Message>,
801            _tools: &[ToolDef],
802        ) -> motosan_chat_tool::Result<LlmResponse> {
803            self.responses
804                .lock()
805                .expect("responses mutex")
806                .pop_front()
807                .ok_or_else(|| motosan_chat_tool::Error::new("no scripted response"))
808        }
809
810        async fn stream(&self, _messages: Vec<Message>) -> LlmStream {
811            Box::pin(stream::iter(Vec::<String>::new()))
812        }
813    }
814
815    struct TestChannel;
816
817    #[async_trait]
818    impl Channel for TestChannel {
819        fn name(&self) -> &str {
820            "test"
821        }
822
823        async fn listen(&self, _tx: mpsc::Sender<IncomingEvent>) -> motosan_chat_core::Result<()> {
824            Ok(())
825        }
826
827        async fn stream_chunk(
828            &self,
829            _user_id: &str,
830            _chunk: &str,
831        ) -> motosan_chat_core::Result<()> {
832            Ok(())
833        }
834
835        async fn stream_done(&self, _user_id: &str) -> motosan_chat_core::Result<()> {
836            Ok(())
837        }
838
839        async fn send(&self, _user_id: &str, _text: &str) -> motosan_chat_core::Result<()> {
840            Ok(())
841        }
842    }
843
844    fn test_thread() -> Thread {
845        Thread::new(
846            IncomingEvent {
847                id: "evt-1".to_string(),
848                platform: "test".to_string(),
849                user_id: "u1".to_string(),
850                channel_id: "c1".to_string(),
851                text: "hello".to_string(),
852                reply_token: None,
853                attachments: vec![],
854                chat_type: ChatType::DirectMessage,
855                kind: IncomingEventKind::Message,
856            },
857            Arc::new(TestChannel),
858        )
859    }
860
861    fn scripted_llm(responses: Vec<LlmResponse>) -> ScriptedLlm {
862        ScriptedLlm {
863            responses: Arc::new(Mutex::new(VecDeque::from(responses))),
864        }
865    }
866
867    #[test]
868    fn place_order_tool_blocks_risk_violations() {
869        let risk_guard = RiskGuard::new(hyper_risk::risk::RiskConfig::default());
870        let tool = PlaceOrderTool::new(risk_guard, AccountState::default());
871
872        let runtime = tokio::runtime::Runtime::new().expect("runtime");
873        let output = runtime.block_on(tool.call(
874            serde_json::json!({"asset": 0, "is_buy": true, "price": "65000", "size": "10000000"}),
875            &ToolContext {
876                caller_id: "test".to_string(),
877                platform: "test".to_string(),
878                extra: Default::default(),
879            },
880        ));
881
882        assert!(output.is_error);
883    }
884
885    #[test]
886    fn do_nothing_tool_returns_noop_json() {
887        let tool = DoNothingTool;
888        let runtime = tokio::runtime::Runtime::new().expect("runtime");
889        let output = runtime.block_on(tool.call(
890            serde_json::json!({"reason": "wait"}),
891            &ToolContext {
892                caller_id: "test".to_string(),
893                platform: "test".to_string(),
894                extra: Default::default(),
895            },
896        ));
897
898        assert!(!output.is_error);
899    }
900
901    #[test]
902    fn order_request_parser_handles_string_numbers() {
903        let req = order_request_from_args(&serde_json::json!({
904            "asset": 0,
905            "is_buy": false,
906            "price": "100.5",
907            "size": "2.0"
908        }));
909
910        assert_eq!(req.symbol, "0");
911        assert_eq!(req.side, "sell");
912        assert!((req.price - 100.5).abs() < f64::EPSILON);
913        assert!((req.size - 2.0).abs() < f64::EPSILON);
914    }
915
916    #[tokio::test]
917    async fn integration_single_tool_call_then_final_answer() {
918        let llm = scripted_llm(vec![
919            LlmResponse::ToolCall {
920                id: "tc-1".to_string(),
921                name: "get_positions".to_string(),
922                args: serde_json::json!({}),
923            },
924            LlmResponse::Message("do_nothing".to_string()),
925        ]);
926
927        let loop_engine = AgentLoop::builder()
928            .tool(GetPositionsTool::new(test_position_manager()))
929            .tool(DoNothingTool)
930            .max_iterations(5)
931            .build();
932
933        let result = loop_engine
934            .run(&llm, &test_thread(), vec![Message::user("status")])
935            .await
936            .expect("agent loop should succeed");
937
938        assert_eq!(result.tool_calls.len(), 1);
939        assert_eq!(result.tool_calls[0].0, "get_positions");
940        assert_eq!(result.answer, "do_nothing");
941    }
942
943    #[tokio::test]
944    async fn integration_positions_then_place_order_then_final_reasoning() {
945        let llm = scripted_llm(vec![
946            LlmResponse::ToolCall {
947                id: "tc-1".to_string(),
948                name: "get_positions".to_string(),
949                args: serde_json::json!({}),
950            },
951            LlmResponse::ToolCall {
952                id: "tc-2".to_string(),
953                name: "place_order".to_string(),
954                args: serde_json::json!({"asset": 0, "is_buy": true, "price": "65000", "size": "0.01"}),
955            },
956            LlmResponse::Message("placed".to_string()),
957        ]);
958
959        let risk_guard = RiskGuard::new(hyper_risk::risk::RiskConfig::default());
960        let loop_engine = AgentLoop::builder()
961            .tool(GetPositionsTool::new(test_position_manager()))
962            .tool(PlaceOrderTool::new(risk_guard, AccountState::default()))
963            .max_iterations(5)
964            .build();
965
966        let result = loop_engine
967            .run(&llm, &test_thread(), vec![Message::user("trade")])
968            .await
969            .expect("agent loop should succeed");
970
971        assert_eq!(result.tool_calls.len(), 2);
972        assert_eq!(result.tool_calls[0].0, "get_positions");
973        assert_eq!(result.tool_calls[1].0, "place_order");
974        assert_eq!(result.answer, "placed");
975    }
976
977    #[tokio::test]
978    async fn integration_risk_blocked_then_llm_adjusts_to_do_nothing() {
979        let llm = scripted_llm(vec![
980            LlmResponse::ToolCall {
981                id: "tc-1".to_string(),
982                name: "place_order".to_string(),
983                args: serde_json::json!({"asset": 0, "is_buy": true, "price": "65000", "size": "10000000"}),
984            },
985            LlmResponse::ToolCall {
986                id: "tc-2".to_string(),
987                name: "do_nothing".to_string(),
988                args: serde_json::json!({"reason": "risk_blocked"}),
989            },
990            LlmResponse::Message("hold".to_string()),
991        ]);
992
993        let risk_guard = RiskGuard::new(hyper_risk::risk::RiskConfig::default());
994        let loop_engine = AgentLoop::builder()
995            .tool(PlaceOrderTool::new(risk_guard, AccountState::default()))
996            .tool(DoNothingTool)
997            .max_iterations(5)
998            .build();
999
1000        let result = loop_engine
1001            .run(&llm, &test_thread(), vec![Message::user("oversized")])
1002            .await
1003            .expect("agent loop should succeed");
1004
1005        assert_eq!(result.tool_calls.len(), 2);
1006        assert_eq!(result.tool_calls[0].0, "place_order");
1007        assert_eq!(result.tool_calls[1].0, "do_nothing");
1008        assert_eq!(result.answer, "hold");
1009    }
1010
1011    #[tokio::test]
1012    async fn integration_max_iterations_guard_triggers_error() {
1013        let llm = scripted_llm(vec![
1014            LlmResponse::ToolCall {
1015                id: "tc-1".to_string(),
1016                name: "do_nothing".to_string(),
1017                args: serde_json::json!({}),
1018            },
1019            LlmResponse::ToolCall {
1020                id: "tc-2".to_string(),
1021                name: "do_nothing".to_string(),
1022                args: serde_json::json!({}),
1023            },
1024            LlmResponse::ToolCall {
1025                id: "tc-3".to_string(),
1026                name: "do_nothing".to_string(),
1027                args: serde_json::json!({}),
1028            },
1029        ]);
1030
1031        let loop_engine = AgentLoop::builder()
1032            .tool(DoNothingTool)
1033            .max_iterations(2)
1034            .build();
1035
1036        let result = loop_engine
1037            .run(&llm, &test_thread(), vec![Message::user("loop")])
1038            .await;
1039
1040        assert!(matches!(result, Err(AgentError::MaxIterations(2))));
1041    }
1042
1043    #[test]
1044    fn close_all_positions_tool_def_is_valid() {
1045        let tool = CloseAllPositionsTool;
1046        let def = tool.def();
1047        assert_eq!(def.name, "close_all_positions");
1048        assert!(!def.description.is_empty());
1049        assert!(def.description.contains("Emergency"));
1050        assert_eq!(def.input_schema["type"], "object");
1051        assert!(def.input_schema["properties"].get("reason").is_some());
1052    }
1053
1054    #[test]
1055    fn close_all_positions_tool_returns_accepted() {
1056        let tool = CloseAllPositionsTool;
1057        let runtime = tokio::runtime::Runtime::new().expect("runtime");
1058        let output = runtime.block_on(tool.call(
1059            serde_json::json!({"reason": "circuit_breaker"}),
1060            &ToolContext {
1061                caller_id: "test".to_string(),
1062                platform: "test".to_string(),
1063                extra: Default::default(),
1064            },
1065        ));
1066
1067        assert!(!output.is_error);
1068    }
1069
1070    #[test]
1071    fn get_funding_rate_tool_def_is_valid() {
1072        let tool = GetFundingRateTool::new(true);
1073        let def = tool.def();
1074        assert_eq!(def.name, "get_funding_rate");
1075        assert!(!def.description.is_empty());
1076        assert!(def.description.contains("funding rate"));
1077        assert_eq!(def.input_schema["type"], "object");
1078        assert!(def.input_schema.get("properties").is_some());
1079        assert!(def.input_schema["properties"].get("symbol").is_some());
1080    }
1081
1082    #[test]
1083    fn get_funding_rate_tool_missing_symbol_returns_error() {
1084        let tool = GetFundingRateTool::new(true);
1085        let runtime = tokio::runtime::Runtime::new().expect("runtime");
1086        let output = runtime.block_on(tool.call(
1087            serde_json::json!({}),
1088            &ToolContext {
1089                caller_id: "test".to_string(),
1090                platform: "test".to_string(),
1091                extra: Default::default(),
1092            },
1093        ));
1094        assert!(output.is_error);
1095    }
1096
1097    #[test]
1098    fn get_market_data_tool_def_is_valid() {
1099        let tool = GetMarketDataTool::new(true);
1100        let def = tool.def();
1101        assert_eq!(def.name, "get_market_data");
1102        assert!(!def.description.is_empty());
1103        assert!(def.description.contains("orderbook"));
1104        assert_eq!(def.input_schema["type"], "object");
1105        assert!(def.input_schema.get("properties").is_some());
1106        assert!(def.input_schema["properties"].get("symbol").is_some());
1107        assert!(def.input_schema["properties"].get("depth").is_some());
1108    }
1109
1110    #[test]
1111    fn get_market_data_tool_missing_symbol_returns_error() {
1112        let tool = GetMarketDataTool::new(true);
1113        let runtime = tokio::runtime::Runtime::new().expect("runtime");
1114        let output = runtime.block_on(tool.call(
1115            serde_json::json!({}),
1116            &ToolContext {
1117                caller_id: "test".to_string(),
1118                platform: "test".to_string(),
1119                extra: Default::default(),
1120            },
1121        ));
1122        assert!(output.is_error);
1123    }
1124
1125    #[test]
1126    fn get_open_orders_tool_def_is_valid() {
1127        let tool = GetOpenOrdersTool::new(true);
1128        let def = tool.def();
1129        assert_eq!(def.name, "get_open_orders");
1130        assert!(!def.description.is_empty());
1131        assert!(def.description.contains("open"));
1132        assert_eq!(def.input_schema["type"], "object");
1133        assert!(def.input_schema.get("properties").is_some());
1134        assert!(def.input_schema["properties"].get("symbol").is_some());
1135    }
1136
1137    #[test]
1138    fn get_open_orders_tool_returns_result_without_symbol() {
1139        let tool = GetOpenOrdersTool::new(true);
1140        let runtime = tokio::runtime::Runtime::new().expect("runtime");
1141        let output = runtime.block_on(tool.call(
1142            serde_json::json!({}),
1143            &ToolContext {
1144                caller_id: "test".to_string(),
1145                platform: "test".to_string(),
1146                extra: Default::default(),
1147            },
1148        ));
1149        assert!(!output.is_error);
1150    }
1151
1152    #[test]
1153    fn get_trade_history_tool_def_is_valid() {
1154        let tool = GetTradeHistoryTool::new(true);
1155        let def = tool.def();
1156        assert_eq!(def.name, "get_trade_history");
1157        assert!(!def.description.is_empty());
1158        assert!(def.description.contains("trade fills"));
1159        assert_eq!(def.input_schema["type"], "object");
1160        assert!(def.input_schema.get("properties").is_some());
1161        assert!(def.input_schema["properties"].get("symbol").is_some());
1162        assert!(def.input_schema["properties"].get("limit").is_some());
1163    }
1164
1165    #[test]
1166    fn get_trade_history_tool_returns_result_without_params() {
1167        let tool = GetTradeHistoryTool::new(true);
1168        let runtime = tokio::runtime::Runtime::new().expect("runtime");
1169        let output = runtime.block_on(tool.call(
1170            serde_json::json!({}),
1171            &ToolContext {
1172                caller_id: "test".to_string(),
1173                platform: "test".to_string(),
1174                extra: Default::default(),
1175            },
1176        ));
1177        assert!(!output.is_error);
1178    }
1179
1180    #[test]
1181    fn get_trade_history_tool_returns_result_with_symbol_and_limit() {
1182        let tool = GetTradeHistoryTool::new(true);
1183        let runtime = tokio::runtime::Runtime::new().expect("runtime");
1184        let output = runtime.block_on(tool.call(
1185            serde_json::json!({"symbol": "BTC-PERP", "limit": 10}),
1186            &ToolContext {
1187                caller_id: "test".to_string(),
1188                platform: "test".to_string(),
1189                extra: Default::default(),
1190            },
1191        ));
1192        assert!(!output.is_error);
1193    }
1194
1195    #[test]
1196    fn get_open_orders_tool_returns_result_with_symbol() {
1197        let tool = GetOpenOrdersTool::new(true);
1198        let runtime = tokio::runtime::Runtime::new().expect("runtime");
1199        let output = runtime.block_on(tool.call(
1200            serde_json::json!({"symbol": "BTC-PERP"}),
1201            &ToolContext {
1202                caller_id: "test".to_string(),
1203                platform: "test".to_string(),
1204                extra: Default::default(),
1205            },
1206        ));
1207        assert!(!output.is_error);
1208    }
1209}