1use std::collections::HashMap;
21use std::future::Future;
22use std::pin::Pin;
23use std::sync::Arc;
24
25use rmcp::model::Tool;
26use serde_json::Value;
27
28use crate::context::AdapterContext;
29use crate::error::AdapterError;
30
// Tool set registered by `ToolRegistry::build` only when the context has credentials.
pub mod account;
// Tool set registered unconditionally by `ToolRegistry::build`.
pub mod public;
// NOTE(review): presumably shared schema helpers for tool descriptors — confirm.
pub mod schema;
// Tool set registered by `ToolRegistry::build` only when the context has credentials
// AND `config.allow_trading` is set.
pub mod trading;
35
/// Capability tier of a tool, deciding which runtime configuration must be
/// present before the tool is registered or allowed to run.
///
/// Variants are declared from least to most demanding, so the derived `Ord`
/// ranks `Read < Account < Trading`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum ToolClass {
    /// Read-only tools; never gated.
    Read,
    /// Tools that require API credentials.
    Account,
    /// Tools that require API credentials plus the trading opt-in.
    Trading,
}

impl ToolClass {
    /// Human-readable name of the setting that enables this class, meant for
    /// embedding in "tool not enabled" error messages.
    ///
    /// For [`ToolClass::Trading`] only the trading opt-in is named; a missing
    /// credential requirement is worded separately where the gate is checked.
    #[must_use]
    pub const fn flag(self) -> &'static str {
        match self {
            Self::Trading => "--allow-trading",
            Self::Account => "DERIBIT_CLIENT_ID + DERIBIT_CLIENT_SECRET",
            Self::Read => "(always enabled)",
        }
    }
}
69
70pub type ToolFuture<'a> = Pin<Box<dyn Future<Output = Result<Value, AdapterError>> + Send + 'a>>;
75
76pub type ToolHandlerFn =
79 Arc<dyn for<'a> Fn(&'a AdapterContext, Value) -> ToolFuture<'a> + Send + Sync + 'static>;
80
81#[derive(Clone)]
88pub struct ToolEntry {
89 pub(crate) descriptor: Tool,
92 pub(crate) class: ToolClass,
94 pub(crate) handler: ToolHandlerFn,
96}
97
98impl ToolEntry {
99 #[must_use]
101 pub fn descriptor(&self) -> &Tool {
102 &self.descriptor
103 }
104
105 #[must_use]
107 pub fn class(&self) -> ToolClass {
108 self.class
109 }
110}
111
112impl std::fmt::Debug for ToolEntry {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("ToolEntry")
115 .field("descriptor", &self.descriptor)
116 .field("class", &self.class)
117 .field("handler", &"<dyn Fn>")
118 .finish()
119 }
120}
121
122#[derive(Debug, Default, Clone)]
127pub struct ToolRegistry {
128 entries: HashMap<String, ToolEntry>,
129}
130
131impl ToolRegistry {
132 #[must_use]
134 pub fn new() -> Self {
135 Self::default()
136 }
137
138 #[must_use]
152 pub fn build(ctx: &AdapterContext) -> Self {
153 let mut registry = Self::new();
154 public::register(&mut registry);
155 if ctx.has_credentials() {
156 account::register(&mut registry);
157 }
158 if ctx.has_credentials() && ctx.config.allow_trading {
159 trading::register(&mut registry);
160 }
161 registry
162 }
163
164 #[allow(dead_code)]
177 pub(crate) fn insert(&mut self, entry: ToolEntry) -> Option<ToolEntry> {
178 let name = entry.descriptor.name.to_string();
179 self.entries.insert(name, entry)
180 }
181
182 #[must_use]
184 pub fn list(&self) -> Vec<Tool> {
185 let mut tools: Vec<Tool> = self
186 .entries
187 .values()
188 .map(|e| e.descriptor.clone())
189 .collect();
190 tools.sort_by(|a, b| a.name.cmp(&b.name));
191 tools
192 }
193
194 #[must_use]
196 pub fn len(&self) -> usize {
197 self.entries.len()
198 }
199
200 #[must_use]
202 pub fn is_empty(&self) -> bool {
203 self.entries.is_empty()
204 }
205
206 #[must_use]
208 pub fn get(&self, name: &str) -> Option<&ToolEntry> {
209 self.entries.get(name)
210 }
211
212 #[must_use]
214 pub fn contains(&self, name: &str) -> bool {
215 self.entries.contains_key(name)
216 }
217
218 pub async fn call(
233 &self,
234 ctx: &AdapterContext,
235 name: &str,
236 input: Value,
237 ) -> Result<Value, AdapterError> {
238 let entry = self
239 .get(name)
240 .ok_or_else(|| AdapterError::validation("name", format!("unknown tool: `{name}`")))?;
241
242 check_class_enabled(entry.class, ctx, &entry.descriptor.name)?;
243
244 (entry.handler)(ctx, input).await
245 }
246}
247
248#[inline(never)]
256fn check_class_enabled(
257 class: ToolClass,
258 ctx: &AdapterContext,
259 name: &str,
260) -> Result<(), AdapterError> {
261 match class {
262 ToolClass::Read => Ok(()),
263 ToolClass::Account => {
264 if ctx.has_credentials() {
265 Ok(())
266 } else {
267 Err(AdapterError::NotEnabled {
268 tool: name.to_string(),
269 flag: ToolClass::Account.flag().to_string(),
270 })
271 }
272 }
273 ToolClass::Trading => {
274 let creds = ctx.has_credentials();
275 let trading = ctx.config.allow_trading;
276 if creds && trading {
277 return Ok(());
278 }
279 let flag = match (creds, trading) {
280 (false, false) => "DERIBIT_CLIENT_ID + DERIBIT_CLIENT_SECRET + --allow-trading",
281 (false, true) => ToolClass::Account.flag(),
282 (true, false) => "--allow-trading",
283 (true, true) => unreachable!("returned Ok above"),
284 };
285 Err(AdapterError::NotEnabled {
286 tool: name.to_string(),
287 flag: flag.to_string(),
288 })
289 }
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{Config, LogFormat, OrderTransport, Transport};
    use rmcp::model::Tool;
    use serde_json::json;
    use std::net::SocketAddr;
    use std::sync::Arc;

    /// Test configuration with optional credentials and trading opt-in.
    fn cfg(with_creds: bool, allow_trading: bool) -> Config {
        Config {
            endpoint: "https://test.deribit.com".to_string(),
            client_id: with_creds.then(|| "id".to_string()),
            client_secret: with_creds.then(|| "secret".to_string()),
            allow_trading,
            max_order_usd: None,
            transport: Transport::Stdio,
            http_listen: SocketAddr::from(([127, 0, 0, 1], 8723)),
            http_bearer_token: None,
            log_format: LogFormat::Text,
            order_transport: OrderTransport::Http,
        }
    }

    fn ctx(with_creds: bool, allow_trading: bool) -> AdapterContext {
        AdapterContext::new(Arc::new(cfg(with_creds, allow_trading))).expect("ctx")
    }

    fn empty_schema() -> Arc<serde_json::Map<String, Value>> {
        Arc::new(serde_json::Map::new())
    }

    /// Entry whose handler always resolves to `{"ok": true}`.
    fn make_entry(name: &'static str, class: ToolClass) -> ToolEntry {
        let descriptor = Tool::new(
            std::borrow::Cow::Borrowed(name),
            "test tool",
            empty_schema(),
        );
        let handler: ToolHandlerFn =
            Arc::new(|_ctx, _input| Box::pin(async move { Ok(json!({"ok": true})) }));
        ToolEntry {
            descriptor,
            class,
            handler,
        }
    }

    /// Assert that `err` is `NotEnabled` for `tool` with exactly `flag`.
    fn assert_not_enabled(err: AdapterError, tool: &str, flag: &str) {
        match err {
            AdapterError::NotEnabled {
                tool: got_tool,
                flag: got_flag,
            } => {
                assert_eq!(got_tool, tool);
                assert_eq!(got_flag, flag);
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn class_flags_match_documentation() {
        assert_eq!(ToolClass::Read.flag(), "(always enabled)");
        assert_eq!(
            ToolClass::Account.flag(),
            "DERIBIT_CLIENT_ID + DERIBIT_CLIENT_SECRET"
        );
        assert_eq!(ToolClass::Trading.flag(), "--allow-trading");
    }

    #[test]
    fn registry_starts_empty() {
        let r = ToolRegistry::new();
        assert!(r.is_empty());
        assert_eq!(r.len(), 0);
        assert!(r.list().is_empty());
    }

    #[test]
    fn registry_lists_sorted_by_name() {
        let mut r = ToolRegistry::new();
        r.insert(make_entry("get_ticker", ToolClass::Read));
        r.insert(make_entry("get_book", ToolClass::Read));
        let listed = r.list();
        let names: Vec<&str> = listed.iter().map(|t| t.name.as_ref()).collect();
        assert_eq!(names, vec!["get_book", "get_ticker"]);
    }

    #[test]
    fn insert_replaces_and_returns_previous_entry() {
        let mut r = ToolRegistry::new();
        assert!(r.insert(make_entry("dup", ToolClass::Read)).is_none());
        let previous = r
            .insert(make_entry("dup", ToolClass::Account))
            .expect("second insert returns the replaced entry");
        assert_eq!(previous.class(), ToolClass::Read);
        assert_eq!(r.len(), 1);
        assert!(r.contains("dup"));
        assert!(!r.contains("missing"));
        assert_eq!(r.get("dup").expect("entry").class(), ToolClass::Account);
    }

    #[test]
    fn entry_debug_hides_handler() {
        let entry = make_entry("ping", ToolClass::Read);
        let rendered = format!("{entry:?}");
        assert!(rendered.contains("ToolEntry"), "{rendered}");
        assert!(rendered.contains("<dyn Fn>"), "{rendered}");
        assert!(rendered.contains("Read"), "{rendered}");
    }

    #[test]
    fn build_without_creds_includes_only_read() {
        let registry = ToolRegistry::build(&ctx(false, false));
        assert_eq!(registry.len(), 14);
        for tool in registry.list() {
            let entry = registry.get(tool.name.as_ref()).expect("entry");
            assert_eq!(entry.class, ToolClass::Read, "{}", tool.name);
        }
    }

    #[test]
    fn build_with_creds_keeps_read_tools_and_excludes_trading() {
        let read_only = ToolRegistry::build(&ctx(false, false));
        let with_creds = ToolRegistry::build(&ctx(true, false));
        for tool in read_only.list() {
            assert!(with_creds.contains(tool.name.as_ref()), "{}", tool.name);
        }
        for tool in with_creds.list() {
            let entry = with_creds.get(tool.name.as_ref()).expect("entry");
            assert_ne!(entry.class(), ToolClass::Trading, "{}", tool.name);
        }
    }

    #[test]
    fn build_with_creds_and_trading_is_a_superset() {
        let with_creds = ToolRegistry::build(&ctx(true, false));
        let with_trading = ToolRegistry::build(&ctx(true, true));
        for tool in with_creds.list() {
            assert!(with_trading.contains(tool.name.as_ref()), "{}", tool.name);
        }
    }

    #[tokio::test]
    async fn dispatch_unknown_tool_returns_validation() {
        let registry = ToolRegistry::new();
        let ctx = ctx(false, false);
        let err = registry
            .call(&ctx, "no_such_tool", Value::Null)
            .await
            .unwrap_err();
        match err {
            AdapterError::Validation { field, .. } => assert_eq!(field, "name"),
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch_read_class_succeeds_without_creds() {
        let mut registry = ToolRegistry::new();
        registry.insert(make_entry("ping", ToolClass::Read));
        let ctx = ctx(false, false);
        let out = registry.call(&ctx, "ping", Value::Null).await.expect("ok");
        assert_eq!(out, json!({"ok": true}));
    }

    #[tokio::test]
    async fn dispatch_account_class_requires_credentials() {
        let mut registry = ToolRegistry::new();
        registry.insert(make_entry("get_account_summary", ToolClass::Account));
        let ctx = ctx(false, false);
        let err = registry
            .call(&ctx, "get_account_summary", Value::Null)
            .await
            .unwrap_err();
        assert_not_enabled(err, "get_account_summary", ToolClass::Account.flag());
    }

    #[tokio::test]
    async fn dispatch_account_class_succeeds_with_credentials() {
        let mut registry = ToolRegistry::new();
        registry.insert(make_entry("get_account_summary", ToolClass::Account));
        let ctx = ctx(true, false);
        registry
            .call(&ctx, "get_account_summary", Value::Null)
            .await
            .expect("ok");
    }

    #[tokio::test]
    async fn dispatch_trading_class_requires_allow_trading_flag() {
        let mut registry = ToolRegistry::new();
        registry.insert(make_entry("place_order", ToolClass::Trading));
        let ctx = ctx(true, false);
        let err = registry
            .call(&ctx, "place_order", Value::Null)
            .await
            .unwrap_err();
        assert_not_enabled(err, "place_order", "--allow-trading");
    }

    #[tokio::test]
    async fn dispatch_trading_class_without_creds_or_flag_names_both() {
        let mut registry = ToolRegistry::new();
        registry.insert(make_entry("place_order", ToolClass::Trading));
        let ctx = ctx(false, false);
        let err = registry
            .call(&ctx, "place_order", Value::Null)
            .await
            .unwrap_err();
        assert_not_enabled(
            err,
            "place_order",
            "DERIBIT_CLIENT_ID + DERIBIT_CLIENT_SECRET + --allow-trading",
        );
    }

    #[tokio::test]
    async fn dispatch_trading_class_with_flag_but_no_creds_names_credentials() {
        let mut registry = ToolRegistry::new();
        registry.insert(make_entry("place_order", ToolClass::Trading));
        let ctx = ctx(false, true);
        let err = registry
            .call(&ctx, "place_order", Value::Null)
            .await
            .unwrap_err();
        assert_not_enabled(err, "place_order", ToolClass::Account.flag());
    }

    #[tokio::test]
    async fn dispatch_trading_class_succeeds_with_creds_and_flag() {
        let mut registry = ToolRegistry::new();
        registry.insert(make_entry("place_order", ToolClass::Trading));
        let ctx = ctx(true, true);
        registry
            .call(&ctx, "place_order", Value::Null)
            .await
            .expect("ok");
    }
}