1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use hyper_exchange::ExchangeClient;
6use hyper_risk::risk::{record_risk_alert, AccountState, OrderRequest, RiskGuard};
7use motosan_chat_tool::{Tool, ToolContext, ToolDef, ToolResult, Value};
8
9use hyper_agent_core::position_manager::PositionManager;
10
11pub struct GetPositionsTool {
12 position_manager: Arc<PositionManager>,
13}
14
15impl GetPositionsTool {
16 pub fn new(position_manager: Arc<PositionManager>) -> Self {
17 Self { position_manager }
18 }
19}
20
21impl Tool for GetPositionsTool {
22 fn def(&self) -> ToolDef {
23 ToolDef {
24 name: "get_positions".to_string(),
25 description: "Get current open positions and account equity".to_string(),
26 input_schema: serde_json::json!({"type": "object", "properties": {}}),
27 }
28 }
29
30 fn call(
31 &self,
32 _args: Value,
33 _ctx: &ToolContext,
34 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
35 let pm = self.position_manager.clone();
36 Box::pin(async move {
37 match pm.list_open().await {
38 Ok(positions) => {
39 let count = positions.len();
40 let pos_json: Vec<serde_json::Value> = positions
41 .iter()
42 .map(|p| {
43 serde_json::json!({
44 "id": p.id,
45 "market": p.market,
46 "side": p.side,
47 "size": p.size,
48 "entry_price": p.entry_price,
49 "current_price": p.current_price,
50 "pnl": p.pnl,
51 "mode": p.mode,
52 })
53 })
54 .collect();
55 ToolResult::json(serde_json::json!({"positions": pos_json, "count": count}))
56 }
57 Err(e) => ToolResult::error(format!("Failed to list positions: {}", e)),
58 }
59 })
60 }
61}
62
63pub struct PlaceOrderTool {
64 risk_guard: RiskGuard,
65 account_state: AccountState,
66}
67
68impl PlaceOrderTool {
69 pub fn new(risk_guard: RiskGuard, account_state: AccountState) -> Self {
70 Self {
71 risk_guard,
72 account_state,
73 }
74 }
75}
76
77impl Tool for PlaceOrderTool {
78 fn def(&self) -> ToolDef {
79 ToolDef {
80 name: "place_order".to_string(),
81 description: "Submit a place order request".to_string(),
82 input_schema: serde_json::json!({
83 "type": "object",
84 "properties": {
85 "asset": {"type": "number"},
86 "is_buy": {"type": "boolean"},
87 "price": {"type": ["string", "number"]},
88 "size": {"type": ["string", "number"]}
89 },
90 "required": ["asset", "is_buy", "price", "size"]
91 }),
92 }
93 }
94
95 fn call(
96 &self,
97 args: Value,
98 _ctx: &ToolContext,
99 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
100 Box::pin(async move {
101 let order_req = order_request_from_args(&args);
102 match self.risk_guard.check_order(&order_req, &self.account_state) {
103 Ok(()) => ToolResult::json(serde_json::json!({
104 "status": "accepted",
105 "request": args,
106 })),
107 Err(violation) => {
108 record_risk_alert(&violation);
109 ToolResult::error(violation.to_string())
110 }
111 }
112 })
113 }
114}
115
116pub struct CancelOrderTool {
117 risk_guard: RiskGuard,
118 account_state: AccountState,
119}
120
121impl CancelOrderTool {
122 pub fn new(risk_guard: RiskGuard, account_state: AccountState) -> Self {
123 Self {
124 risk_guard,
125 account_state,
126 }
127 }
128}
129
130impl Tool for CancelOrderTool {
131 fn def(&self) -> ToolDef {
132 ToolDef {
133 name: "cancel_order".to_string(),
134 description: "Cancel an existing order by id".to_string(),
135 input_schema: serde_json::json!({
136 "type": "object",
137 "properties": {"order_id": {"type": ["string", "number"]}},
138 "required": ["order_id"]
139 }),
140 }
141 }
142
143 fn call(
144 &self,
145 args: Value,
146 _ctx: &ToolContext,
147 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
148 Box::pin(async move {
149 let synthetic_request = OrderRequest {
150 symbol: "unknown".to_string(),
151 side: "sell".to_string(),
152 size: 0.0,
153 price: 0.0,
154 };
155 match self
156 .risk_guard
157 .check_order(&synthetic_request, &self.account_state)
158 {
159 Ok(()) => ToolResult::json(serde_json::json!({
160 "status": "accepted",
161 "request": args,
162 })),
163 Err(violation) => {
164 record_risk_alert(&violation);
165 ToolResult::error(violation.to_string())
166 }
167 }
168 })
169 }
170}
171
172pub struct DoNothingTool;
173
174impl Tool for DoNothingTool {
175 fn def(&self) -> ToolDef {
176 ToolDef {
177 name: "do_nothing".to_string(),
178 description: "Return hold/no-action decision".to_string(),
179 input_schema: serde_json::json!({
180 "type": "object",
181 "properties": {"reason": {"type": "string"}}
182 }),
183 }
184 }
185
186 fn call(
187 &self,
188 args: Value,
189 _ctx: &ToolContext,
190 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
191 Box::pin(async move {
192 ToolResult::json(serde_json::json!({
193 "status": "noop",
194 "reason": args.get("reason").and_then(|v| v.as_str()).unwrap_or("No action")
195 }))
196 })
197 }
198}
199
200pub struct SetStopLossTool;
201
202impl Tool for SetStopLossTool {
203 fn def(&self) -> ToolDef {
204 ToolDef {
205 name: "set_stop_loss".to_string(),
206 description: "Set a stop-loss trigger on an existing position. Triggers a market close when the price reaches the specified trigger price.".to_string(),
207 input_schema: serde_json::json!({
208 "type": "object",
209 "properties": {
210 "symbol": {"type": "string", "description": "Trading pair symbol, e.g. BTC-PERP"},
211 "trigger_price": {"type": "number", "description": "Price at which the stop-loss triggers"}
212 },
213 "required": ["symbol", "trigger_price"]
214 }),
215 }
216 }
217
218 fn call(
219 &self,
220 args: Value,
221 _ctx: &ToolContext,
222 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
223 Box::pin(async move {
224 let symbol = args
225 .get("symbol")
226 .and_then(|v| v.as_str())
227 .unwrap_or("unknown");
228 let trigger_price = args
229 .get("trigger_price")
230 .and_then(|v| v.as_f64())
231 .unwrap_or(0.0);
232 ToolResult::json(serde_json::json!({
233 "status": "accepted",
234 "trigger_type": "stop_loss",
235 "symbol": symbol,
236 "trigger_price": trigger_price
237 }))
238 })
239 }
240}
241
242pub struct SetTakeProfitTool;
243
244impl Tool for SetTakeProfitTool {
245 fn def(&self) -> ToolDef {
246 ToolDef {
247 name: "set_take_profit".to_string(),
248 description: "Set a take-profit trigger on an existing position. Triggers a market close when the price reaches the specified trigger price.".to_string(),
249 input_schema: serde_json::json!({
250 "type": "object",
251 "properties": {
252 "symbol": {"type": "string", "description": "Trading pair symbol, e.g. BTC-PERP"},
253 "trigger_price": {"type": "number", "description": "Price at which the take-profit triggers"}
254 },
255 "required": ["symbol", "trigger_price"]
256 }),
257 }
258 }
259
260 fn call(
261 &self,
262 args: Value,
263 _ctx: &ToolContext,
264 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
265 Box::pin(async move {
266 let symbol = args
267 .get("symbol")
268 .and_then(|v| v.as_str())
269 .unwrap_or("unknown");
270 let trigger_price = args
271 .get("trigger_price")
272 .and_then(|v| v.as_f64())
273 .unwrap_or(0.0);
274 ToolResult::json(serde_json::json!({
275 "status": "accepted",
276 "trigger_type": "take_profit",
277 "symbol": symbol,
278 "trigger_price": trigger_price
279 }))
280 })
281 }
282}
283
/// Tool that lists resting orders, optionally filtered by symbol.
pub struct GetOpenOrdersTool {
    /// Selects mainnet (`true`) or testnet (`false`) for the query.
    is_mainnet: bool,
}

impl GetOpenOrdersTool {
    /// Creates the tool for the given network.
    pub fn new(is_mainnet: bool) -> Self {
        GetOpenOrdersTool { is_mainnet }
    }
}
293
294impl Tool for GetOpenOrdersTool {
295 fn def(&self) -> ToolDef {
296 ToolDef {
297 name: "get_open_orders".to_string(),
298 description: "List all open (pending/resting) orders. Optionally filter by a specific trading pair symbol. Returns order details including price, size, side, and order ID.".to_string(),
299 input_schema: serde_json::json!({
300 "type": "object",
301 "properties": {
302 "symbol": {"type": "string", "description": "Optional trading pair symbol to filter by, e.g. BTC-PERP"}
303 }
304 }),
305 }
306 }
307
308 fn call(
309 &self,
310 args: Value,
311 _ctx: &ToolContext,
312 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
313 Box::pin(async move {
314 let symbol_filter = args.get("symbol").and_then(|v| v.as_str());
315 match fetch_open_orders(symbol_filter, self.is_mainnet).await {
316 Ok(info) => ToolResult::json(info),
317 Err(e) => ToolResult::error(format!("Failed to fetch open orders: {}", e)),
318 }
319 })
320 }
321}
322
323pub async fn fetch_open_orders(
328 symbol_filter: Option<&str>,
329 _is_mainnet: bool,
330) -> Result<serde_json::Value, String> {
331 Ok(serde_json::json!({
335 "orders": [],
336 "symbol_filter": symbol_filter,
337 "message": "Open orders query accepted"
338 }))
339}
340
341pub struct CloseAllPositionsTool;
342
343impl Tool for CloseAllPositionsTool {
344 fn def(&self) -> ToolDef {
345 ToolDef {
346 name: "close_all_positions".to_string(),
347 description: "Emergency: close ALL open positions immediately with market orders. Use when circuit breaker triggers or emergency situations arise.".to_string(),
348 input_schema: serde_json::json!({
349 "type": "object",
350 "properties": {
351 "reason": {"type": "string", "description": "Reason for the emergency close"}
352 },
353 "required": ["reason"]
354 }),
355 }
356 }
357
358 fn call(
359 &self,
360 args: Value,
361 _ctx: &ToolContext,
362 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
363 Box::pin(async move {
364 let reason = args
365 .get("reason")
366 .and_then(|v| v.as_str())
367 .unwrap_or("emergency");
368 ToolResult::json(serde_json::json!({
369 "status": "accepted",
370 "action": "close_all_positions",
371 "reason": reason
372 }))
373 })
374 }
375}
376
/// Tool reporting the current funding rate of a perpetual contract.
pub struct GetFundingRateTool {
    /// Selects mainnet (`true`) or testnet (`false`) for the exchange client.
    is_mainnet: bool,
}

impl GetFundingRateTool {
    /// Creates the tool for the given network.
    pub fn new(is_mainnet: bool) -> Self {
        GetFundingRateTool { is_mainnet }
    }
}
386
387impl Tool for GetFundingRateTool {
388 fn def(&self) -> ToolDef {
389 ToolDef {
390 name: "get_funding_rate".to_string(),
391 description: "Get current funding rate for a perpetual contract symbol. Use to assess overnight holding cost.".to_string(),
392 input_schema: serde_json::json!({
393 "type": "object",
394 "properties": {
395 "symbol": {"type": "string", "description": "Trading pair symbol, e.g. BTC-PERP"}
396 },
397 "required": ["symbol"]
398 }),
399 }
400 }
401
402 fn call(
403 &self,
404 args: Value,
405 _ctx: &ToolContext,
406 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
407 Box::pin(async move {
408 let symbol = match args.get("symbol").and_then(|v| v.as_str()) {
409 Some(s) => s,
410 None => return ToolResult::error("Missing required parameter: symbol"),
411 };
412
413 match fetch_funding_rate(symbol, self.is_mainnet).await {
414 Ok(info) => ToolResult::json(info),
415 Err(e) => ToolResult::error(format!("Failed to fetch funding rate: {}", e)),
416 }
417 })
418 }
419}
420
421pub async fn fetch_funding_rate(
425 symbol: &str,
426 is_mainnet: bool,
427) -> Result<serde_json::Value, String> {
428 let client = ExchangeClient::new(is_mainnet);
429
430 let request = serde_json::json!({
432 "type": "metaAndAssetCtxs"
433 });
434
435 let response = client
436 .post_info(request)
437 .await
438 .map_err(|e| format!("API error: {}", e))?;
439
440 let arr = response
442 .as_array()
443 .ok_or_else(|| "Invalid response format".to_string())?;
444
445 if arr.len() < 2 {
446 return Err("Incomplete response from API".to_string());
447 }
448
449 let meta_obj = &arr[0];
450 let asset_ctxs = arr[1]
451 .as_array()
452 .ok_or_else(|| "Invalid asset contexts format".to_string())?;
453
454 let universe = meta_obj
455 .get("universe")
456 .and_then(|u| u.as_array())
457 .ok_or_else(|| "Missing universe in meta".to_string())?;
458
459 let coin = symbol
461 .replace("-PERP", "")
462 .replace("-USDC", "")
463 .replace("-USD", "");
464
465 let idx = universe
466 .iter()
467 .position(|item| item.get("name").and_then(|n| n.as_str()) == Some(&coin))
468 .ok_or_else(|| format!("Symbol {} not found", symbol))?;
469
470 if idx >= asset_ctxs.len() {
471 return Err(format!("Asset context not found for {}", symbol));
472 }
473
474 let ctx = &asset_ctxs[idx];
475
476 let funding_rate = ctx
477 .get("funding")
478 .and_then(|v| v.as_str())
479 .and_then(|s| s.parse::<f64>().ok())
480 .unwrap_or(0.0);
481
482 let mark_price = ctx
483 .get("markPx")
484 .and_then(|v| v.as_str())
485 .and_then(|s| s.parse::<f64>().ok())
486 .unwrap_or(0.0);
487
488 let open_interest = ctx
489 .get("openInterest")
490 .and_then(|v| v.as_str())
491 .and_then(|s| s.parse::<f64>().ok())
492 .unwrap_or(0.0);
493
494 let annualized_rate = funding_rate * 3.0 * 365.0;
496
497 Ok(serde_json::json!({
498 "symbol": symbol,
499 "funding_rate": funding_rate,
500 "funding_rate_pct": format!("{:.4}%", funding_rate * 100.0),
501 "annualized_rate_pct": format!("{:.2}%", annualized_rate * 100.0),
502 "mark_price": mark_price,
503 "open_interest": open_interest,
504 "direction": if funding_rate > 0.0 { "longs_pay_shorts" } else if funding_rate < 0.0 { "shorts_pay_longs" } else { "neutral" }
505 }))
506}
507
/// Tool returning a truncated orderbook snapshot for a symbol.
pub struct GetMarketDataTool {
    /// Selects mainnet (`true`) or testnet (`false`) for the exchange client.
    is_mainnet: bool,
}

impl GetMarketDataTool {
    /// Creates the tool for the given network.
    pub fn new(is_mainnet: bool) -> Self {
        GetMarketDataTool { is_mainnet }
    }
}
517
518impl Tool for GetMarketDataTool {
519 fn def(&self) -> ToolDef {
520 ToolDef {
521 name: "get_market_data".to_string(),
522 description: "Get orderbook (bids/asks) for a symbol. Use to check spread and liquidity before placing limit orders.".to_string(),
523 input_schema: serde_json::json!({
524 "type": "object",
525 "properties": {
526 "symbol": {"type": "string", "description": "Trading pair symbol, e.g. BTC-PERP"},
527 "depth": {"type": "integer", "description": "Number of orderbook levels to return (default 5)"}
528 },
529 "required": ["symbol"]
530 }),
531 }
532 }
533
534 fn call(
535 &self,
536 args: Value,
537 _ctx: &ToolContext,
538 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
539 Box::pin(async move {
540 let symbol = match args.get("symbol").and_then(|v| v.as_str()) {
541 Some(s) => s,
542 None => return ToolResult::error("Missing required parameter: symbol"),
543 };
544 let depth = args.get("depth").and_then(|v| v.as_u64()).unwrap_or(5) as usize;
545
546 match fetch_market_data(symbol, depth, self.is_mainnet).await {
547 Ok(info) => ToolResult::json(info),
548 Err(e) => ToolResult::error(format!("Failed to fetch market data: {}", e)),
549 }
550 })
551 }
552}
553
554pub async fn fetch_market_data(
558 symbol: &str,
559 depth: usize,
560 is_mainnet: bool,
561) -> Result<serde_json::Value, String> {
562 let client = ExchangeClient::new(is_mainnet);
563
564 let coin = symbol
566 .replace("-PERP", "")
567 .replace("-USDC", "")
568 .replace("-USD", "");
569
570 let request = serde_json::json!({
571 "type": "l2Book",
572 "coin": coin
573 });
574
575 let response = client
576 .post_info(request)
577 .await
578 .map_err(|e| format!("API error: {}", e))?;
579
580 let levels = response
582 .get("levels")
583 .and_then(|l| l.as_array())
584 .ok_or_else(|| "Invalid L2 book response format".to_string())?;
585
586 if levels.len() < 2 {
587 return Err("Incomplete L2 book response".to_string());
588 }
589
590 let raw_bids = levels[0]
591 .as_array()
592 .ok_or_else(|| "Invalid bids format".to_string())?;
593 let raw_asks = levels[1]
594 .as_array()
595 .ok_or_else(|| "Invalid asks format".to_string())?;
596
597 let depth = depth.min(20); let parse_level = |entry: &serde_json::Value| -> Option<serde_json::Value> {
600 let px = entry
601 .get("px")
602 .and_then(|v| v.as_str())?
603 .parse::<f64>()
604 .ok()?;
605 let sz = entry
606 .get("sz")
607 .and_then(|v| v.as_str())?
608 .parse::<f64>()
609 .ok()?;
610 let n = entry.get("n").and_then(|v| v.as_u64()).unwrap_or(0);
611 Some(serde_json::json!({
612 "price": px,
613 "size": sz,
614 "num_orders": n
615 }))
616 };
617
618 let bids: Vec<serde_json::Value> = raw_bids
619 .iter()
620 .take(depth)
621 .filter_map(parse_level)
622 .collect();
623 let asks: Vec<serde_json::Value> = raw_asks
624 .iter()
625 .take(depth)
626 .filter_map(parse_level)
627 .collect();
628
629 let best_bid = bids
631 .first()
632 .and_then(|b| b.get("price"))
633 .and_then(|v| v.as_f64());
634 let best_ask = asks
635 .first()
636 .and_then(|a| a.get("price"))
637 .and_then(|v| v.as_f64());
638
639 let (spread, spread_pct, mid_price) = match (best_bid, best_ask) {
640 (Some(bid), Some(ask)) => {
641 let s = ask - bid;
642 let mid = (ask + bid) / 2.0;
643 let pct = if mid > 0.0 { s / mid * 100.0 } else { 0.0 };
644 (Some(s), Some(pct), Some(mid))
645 }
646 _ => (None, None, None),
647 };
648
649 Ok(serde_json::json!({
650 "symbol": symbol,
651 "bids": bids,
652 "asks": asks,
653 "best_bid": best_bid,
654 "best_ask": best_ask,
655 "mid_price": mid_price,
656 "spread": spread,
657 "spread_pct": spread_pct.map(|p| format!("{:.4}%", p)),
658 "depth": depth
659 }))
660}
661
/// Tool listing the user's recent fills, with optional symbol/limit filters.
pub struct GetTradeHistoryTool {
    /// Selects mainnet (`true`) or testnet (`false`) for the query.
    is_mainnet: bool,
}

impl GetTradeHistoryTool {
    /// Creates the tool for the given network.
    pub fn new(is_mainnet: bool) -> Self {
        GetTradeHistoryTool { is_mainnet }
    }
}
671
672impl Tool for GetTradeHistoryTool {
673 fn def(&self) -> ToolDef {
674 ToolDef {
675 name: "get_trade_history".to_string(),
676 description: "Get recent trade fills (executed trades) for the current user. Optionally filter by symbol and limit the number of results.".to_string(),
677 input_schema: serde_json::json!({
678 "type": "object",
679 "properties": {
680 "symbol": {"type": "string", "description": "Optional trading pair symbol to filter by, e.g. BTC-PERP"},
681 "limit": {"type": "integer", "description": "Maximum number of fills to return (default: all)"}
682 }
683 }),
684 }
685 }
686
687 fn call(
688 &self,
689 args: Value,
690 _ctx: &ToolContext,
691 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
692 Box::pin(async move {
693 let symbol_filter = args.get("symbol").and_then(|v| v.as_str());
694 let limit = args
695 .get("limit")
696 .and_then(|v| v.as_u64())
697 .map(|l| l as usize);
698 match fetch_trade_history(symbol_filter, limit, self.is_mainnet).await {
699 Ok(info) => ToolResult::json(info),
700 Err(e) => ToolResult::error(format!("Failed to fetch trade history: {}", e)),
701 }
702 })
703 }
704}
705
706pub async fn fetch_trade_history(
712 symbol_filter: Option<&str>,
713 limit: Option<usize>,
714 _is_mainnet: bool,
715) -> Result<serde_json::Value, String> {
716 Ok(serde_json::json!({
717 "fills": [],
718 "symbol_filter": symbol_filter,
719 "limit": limit,
720 "message": "Trade history query accepted"
721 }))
722}
723
724fn order_request_from_args(args: &Value) -> OrderRequest {
725 let symbol = args
726 .get("asset")
727 .and_then(|v| v.as_u64())
728 .map(|a| a.to_string())
729 .or_else(|| {
730 args.get("symbol")
731 .and_then(|v| v.as_str())
732 .map(|s| s.to_string())
733 })
734 .unwrap_or_else(|| "unknown".to_string());
735
736 let side = args
737 .get("is_buy")
738 .and_then(|v| v.as_bool())
739 .map(|b| if b { "buy" } else { "sell" }.to_string())
740 .or_else(|| {
741 args.get("side")
742 .and_then(|v| v.as_str())
743 .map(ToString::to_string)
744 })
745 .unwrap_or_else(|| "buy".to_string());
746
747 let size = args
748 .get("size")
749 .and_then(|v| {
750 v.as_str()
751 .and_then(|s| s.parse().ok())
752 .or_else(|| v.as_f64())
753 })
754 .unwrap_or(0.0);
755
756 let price = args
757 .get("price")
758 .and_then(|v| {
759 v.as_str()
760 .and_then(|s| s.parse().ok())
761 .or_else(|| v.as_f64())
762 })
763 .unwrap_or(0.0);
764
765 OrderRequest {
766 symbol,
767 side,
768 size,
769 price,
770 }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use futures_util::stream;
    use hyper_agent_core::position_manager::PositionManager;
    use motosan_chat_agent::{AgentError, AgentLoop, LlmClient, LlmResponse, LlmStream, Message};
    use motosan_chat_core::{Channel, ChatType, IncomingEvent, IncomingEventKind, Thread};
    use motosan_chat_tool::ToolDef;
    use tokio::sync::mpsc;

    /// Fresh position manager backed by an in-memory (`:memory:`) store.
    fn test_position_manager() -> Arc<PositionManager> {
        let manager = PositionManager::new(":memory:").expect("in-memory position manager");
        Arc::new(manager)
    }

    /// Context passed to every direct tool invocation in these tests.
    fn test_ctx() -> ToolContext {
        ToolContext {
            caller_id: "test".to_string(),
            platform: "test".to_string(),
            extra: Default::default(),
        }
    }

    /// Drives a single tool call to completion on a dedicated Tokio runtime.
    fn run_tool<T: Tool>(tool: &T, args: Value) -> ToolResult {
        let ctx = test_ctx();
        let runtime = tokio::runtime::Runtime::new().expect("runtime");
        runtime.block_on(tool.call(args, &ctx))
    }

    /// Shorthand for a scripted `LlmResponse::ToolCall`.
    fn tool_call(id: &str, name: &str, args: Value) -> LlmResponse {
        LlmResponse::ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            args,
        }
    }

    /// LLM double that replays a fixed queue of responses, one per `chat` call.
    #[derive(Clone)]
    struct ScriptedLlm {
        responses: Arc<Mutex<VecDeque<LlmResponse>>>,
    }

    #[async_trait]
    impl LlmClient for ScriptedLlm {
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: &[ToolDef],
        ) -> motosan_chat_tool::Result<LlmResponse> {
            let next = self.responses.lock().expect("responses mutex").pop_front();
            next.ok_or_else(|| motosan_chat_tool::Error::new("no scripted response"))
        }

        async fn stream(&self, _messages: Vec<Message>) -> LlmStream {
            Box::pin(stream::iter(Vec::<String>::new()))
        }
    }

    /// Channel double whose operations all succeed without doing anything.
    struct TestChannel;

    #[async_trait]
    impl Channel for TestChannel {
        fn name(&self) -> &str {
            "test"
        }

        async fn listen(&self, _tx: mpsc::Sender<IncomingEvent>) -> motosan_chat_core::Result<()> {
            Ok(())
        }

        async fn stream_chunk(
            &self,
            _user_id: &str,
            _chunk: &str,
        ) -> motosan_chat_core::Result<()> {
            Ok(())
        }

        async fn stream_done(&self, _user_id: &str) -> motosan_chat_core::Result<()> {
            Ok(())
        }

        async fn send(&self, _user_id: &str, _text: &str) -> motosan_chat_core::Result<()> {
            Ok(())
        }
    }

    /// Direct-message thread from user `u1` on the test channel.
    fn test_thread() -> Thread {
        let event = IncomingEvent {
            id: "evt-1".to_string(),
            platform: "test".to_string(),
            user_id: "u1".to_string(),
            channel_id: "c1".to_string(),
            text: "hello".to_string(),
            reply_token: None,
            attachments: vec![],
            chat_type: ChatType::DirectMessage,
            kind: IncomingEventKind::Message,
        };
        Thread::new(event, Arc::new(TestChannel))
    }

    /// Wraps `responses` in a [`ScriptedLlm`].
    fn scripted_llm(responses: Vec<LlmResponse>) -> ScriptedLlm {
        let queue = VecDeque::from(responses);
        ScriptedLlm {
            responses: Arc::new(Mutex::new(queue)),
        }
    }

    #[test]
    fn place_order_tool_blocks_risk_violations() {
        let risk_guard = RiskGuard::new(hyper_risk::risk::RiskConfig::default());
        let tool = PlaceOrderTool::new(risk_guard, AccountState::default());

        let output = run_tool(
            &tool,
            serde_json::json!({"asset": 0, "is_buy": true, "price": "65000", "size": "10000000"}),
        );

        assert!(output.is_error);
    }

    #[test]
    fn do_nothing_tool_returns_noop_json() {
        let output = run_tool(&DoNothingTool, serde_json::json!({"reason": "wait"}));

        assert!(!output.is_error);
    }

    #[test]
    fn order_request_parser_handles_string_numbers() {
        let req = order_request_from_args(&serde_json::json!({
            "asset": 0,
            "is_buy": false,
            "price": "100.5",
            "size": "2.0"
        }));

        assert_eq!(req.symbol, "0");
        assert_eq!(req.side, "sell");
        assert!((req.price - 100.5).abs() < f64::EPSILON);
        assert!((req.size - 2.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn integration_single_tool_call_then_final_answer() {
        let llm = scripted_llm(vec![
            tool_call("tc-1", "get_positions", serde_json::json!({})),
            LlmResponse::Message("do_nothing".to_string()),
        ]);

        let loop_engine = AgentLoop::builder()
            .tool(GetPositionsTool::new(test_position_manager()))
            .tool(DoNothingTool)
            .max_iterations(5)
            .build();

        let result = loop_engine
            .run(&llm, &test_thread(), vec![Message::user("status")])
            .await
            .expect("agent loop should succeed");

        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].0, "get_positions");
        assert_eq!(result.answer, "do_nothing");
    }

    #[tokio::test]
    async fn integration_positions_then_place_order_then_final_reasoning() {
        let llm = scripted_llm(vec![
            tool_call("tc-1", "get_positions", serde_json::json!({})),
            tool_call(
                "tc-2",
                "place_order",
                serde_json::json!({"asset": 0, "is_buy": true, "price": "65000", "size": "0.01"}),
            ),
            LlmResponse::Message("placed".to_string()),
        ]);

        let risk_guard = RiskGuard::new(hyper_risk::risk::RiskConfig::default());
        let loop_engine = AgentLoop::builder()
            .tool(GetPositionsTool::new(test_position_manager()))
            .tool(PlaceOrderTool::new(risk_guard, AccountState::default()))
            .max_iterations(5)
            .build();

        let result = loop_engine
            .run(&llm, &test_thread(), vec![Message::user("trade")])
            .await
            .expect("agent loop should succeed");

        assert_eq!(result.tool_calls.len(), 2);
        assert_eq!(result.tool_calls[0].0, "get_positions");
        assert_eq!(result.tool_calls[1].0, "place_order");
        assert_eq!(result.answer, "placed");
    }

    #[tokio::test]
    async fn integration_risk_blocked_then_llm_adjusts_to_do_nothing() {
        let llm = scripted_llm(vec![
            tool_call(
                "tc-1",
                "place_order",
                serde_json::json!({"asset": 0, "is_buy": true, "price": "65000", "size": "10000000"}),
            ),
            tool_call(
                "tc-2",
                "do_nothing",
                serde_json::json!({"reason": "risk_blocked"}),
            ),
            LlmResponse::Message("hold".to_string()),
        ]);

        let risk_guard = RiskGuard::new(hyper_risk::risk::RiskConfig::default());
        let loop_engine = AgentLoop::builder()
            .tool(PlaceOrderTool::new(risk_guard, AccountState::default()))
            .tool(DoNothingTool)
            .max_iterations(5)
            .build();

        let result = loop_engine
            .run(&llm, &test_thread(), vec![Message::user("oversized")])
            .await
            .expect("agent loop should succeed");

        assert_eq!(result.tool_calls.len(), 2);
        assert_eq!(result.tool_calls[0].0, "place_order");
        assert_eq!(result.tool_calls[1].0, "do_nothing");
        assert_eq!(result.answer, "hold");
    }

    #[tokio::test]
    async fn integration_max_iterations_guard_triggers_error() {
        let llm = scripted_llm(vec![
            tool_call("tc-1", "do_nothing", serde_json::json!({})),
            tool_call("tc-2", "do_nothing", serde_json::json!({})),
            tool_call("tc-3", "do_nothing", serde_json::json!({})),
        ]);

        let loop_engine = AgentLoop::builder()
            .tool(DoNothingTool)
            .max_iterations(2)
            .build();

        let result = loop_engine
            .run(&llm, &test_thread(), vec![Message::user("loop")])
            .await;

        assert!(matches!(result, Err(AgentError::MaxIterations(2))));
    }

    #[test]
    fn close_all_positions_tool_def_is_valid() {
        let def = CloseAllPositionsTool.def();
        assert_eq!(def.name, "close_all_positions");
        assert!(!def.description.is_empty());
        assert!(def.description.contains("Emergency"));
        assert_eq!(def.input_schema["type"], "object");
        assert!(def.input_schema["properties"].get("reason").is_some());
    }

    #[test]
    fn close_all_positions_tool_returns_accepted() {
        let output = run_tool(
            &CloseAllPositionsTool,
            serde_json::json!({"reason": "circuit_breaker"}),
        );

        assert!(!output.is_error);
    }

    #[test]
    fn get_funding_rate_tool_def_is_valid() {
        let def = GetFundingRateTool::new(true).def();
        assert_eq!(def.name, "get_funding_rate");
        assert!(!def.description.is_empty());
        assert!(def.description.contains("funding rate"));
        assert_eq!(def.input_schema["type"], "object");
        assert!(def.input_schema.get("properties").is_some());
        assert!(def.input_schema["properties"].get("symbol").is_some());
    }

    #[test]
    fn get_funding_rate_tool_missing_symbol_returns_error() {
        let output = run_tool(&GetFundingRateTool::new(true), serde_json::json!({}));
        assert!(output.is_error);
    }

    #[test]
    fn get_market_data_tool_def_is_valid() {
        let def = GetMarketDataTool::new(true).def();
        assert_eq!(def.name, "get_market_data");
        assert!(!def.description.is_empty());
        assert!(def.description.contains("orderbook"));
        assert_eq!(def.input_schema["type"], "object");
        assert!(def.input_schema.get("properties").is_some());
        assert!(def.input_schema["properties"].get("symbol").is_some());
        assert!(def.input_schema["properties"].get("depth").is_some());
    }

    #[test]
    fn get_market_data_tool_missing_symbol_returns_error() {
        let output = run_tool(&GetMarketDataTool::new(true), serde_json::json!({}));
        assert!(output.is_error);
    }

    #[test]
    fn get_open_orders_tool_def_is_valid() {
        let def = GetOpenOrdersTool::new(true).def();
        assert_eq!(def.name, "get_open_orders");
        assert!(!def.description.is_empty());
        assert!(def.description.contains("open"));
        assert_eq!(def.input_schema["type"], "object");
        assert!(def.input_schema.get("properties").is_some());
        assert!(def.input_schema["properties"].get("symbol").is_some());
    }

    #[test]
    fn get_open_orders_tool_returns_result_without_symbol() {
        let output = run_tool(&GetOpenOrdersTool::new(true), serde_json::json!({}));
        assert!(!output.is_error);
    }

    #[test]
    fn get_trade_history_tool_def_is_valid() {
        let def = GetTradeHistoryTool::new(true).def();
        assert_eq!(def.name, "get_trade_history");
        assert!(!def.description.is_empty());
        assert!(def.description.contains("trade fills"));
        assert_eq!(def.input_schema["type"], "object");
        assert!(def.input_schema.get("properties").is_some());
        assert!(def.input_schema["properties"].get("symbol").is_some());
        assert!(def.input_schema["properties"].get("limit").is_some());
    }

    #[test]
    fn get_trade_history_tool_returns_result_without_params() {
        let output = run_tool(&GetTradeHistoryTool::new(true), serde_json::json!({}));
        assert!(!output.is_error);
    }

    #[test]
    fn get_trade_history_tool_returns_result_with_symbol_and_limit() {
        let output = run_tool(
            &GetTradeHistoryTool::new(true),
            serde_json::json!({"symbol": "BTC-PERP", "limit": 10}),
        );
        assert!(!output.is_error);
    }

    #[test]
    fn get_open_orders_tool_returns_result_with_symbol() {
        let output = run_tool(
            &GetOpenOrdersTool::new(true),
            serde_json::json!({"symbol": "BTC-PERP"}),
        );
        assert!(!output.is_error);
    }
}