use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use rmcp::model::Tool;
use serde_json::Value;
use crate::context::AdapterContext;
use crate::error::AdapterError;
pub mod account;
pub mod public;
pub mod schema;
pub mod trading;
/// Gating class of a tool.
///
/// The class decides which configuration must be present before a tool is
/// registered by [`ToolRegistry::build`] and before it is dispatched by
/// [`ToolRegistry::call`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[non_exhaustive]
pub enum ToolClass {
    /// Always registered and always dispatchable.
    Read,
    /// Registered and dispatchable only when the context has credentials.
    Account,
    /// Registered and dispatchable only when the context has credentials
    /// and `allow_trading` is set in its config.
    Trading,
}

impl ToolClass {
    /// Returns the static label naming the setting associated with this class.
    ///
    /// The label is what `check_class_enabled` reports in the `flag` field of
    /// `AdapterError::NotEnabled`. Note that for [`ToolClass::Trading`] this
    /// is only `--allow-trading`, even though dispatch also requires
    /// credentials.
    #[must_use]
    pub const fn flag(self) -> &'static str {
        let label: &'static str = match self {
            ToolClass::Read => "(always enabled)",
            ToolClass::Account => "DERIBIT_CLIENT_ID + DERIBIT_CLIENT_SECRET",
            ToolClass::Trading => "--allow-trading",
        };
        label
    }
}
/// Boxed, pinned, `Send` future returned by a tool handler.
///
/// Resolves to the tool's JSON result or an [`AdapterError`]; the future may
/// borrow data that lives for `'a`.
pub type ToolFuture<'a> = Pin<Box<dyn Future<Output = Result<Value, AdapterError>> + Send + 'a>>;
/// Shared, type-erased tool handler.
///
/// A handler receives a borrowed [`AdapterContext`] and the JSON input, and
/// returns a [`ToolFuture`] whose lifetime is tied to the context borrow.
/// The closure itself must be `Send + Sync + 'static`; the `Arc` makes
/// cloning a handler a reference-count bump.
pub type ToolHandlerFn =
Arc<dyn for<'a> Fn(&'a AdapterContext, Value) -> ToolFuture<'a> + Send + Sync + 'static>;
/// A single registered tool: its descriptor, its gating class, and the
/// handler invoked on dispatch.
#[derive(Clone)]
pub struct ToolEntry {
/// `rmcp` tool descriptor; its `name` is used as the registry key.
pub(crate) descriptor: Tool,
/// Gating class checked before the handler is invoked.
pub(crate) class: ToolClass,
/// Shared handler called with the adapter context and the JSON input.
pub(crate) handler: ToolHandlerFn,
}
impl ToolEntry {
    /// Borrows the tool descriptor stored in this entry.
    #[must_use]
    pub fn descriptor(&self) -> &Tool {
        let Self { descriptor, .. } = self;
        descriptor
    }

    /// Returns the gating class of this entry (a `Copy` value).
    #[must_use]
    pub fn class(&self) -> ToolClass {
        let Self { class, .. } = self;
        *class
    }
}
/// Manual `Debug` implementation: the handler is a trait object without a
/// `Debug` impl, so it is rendered as the fixed placeholder `"<dyn Fn>"`.
impl std::fmt::Debug for ToolEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring so a newly added field cannot be
        // silently left out of the debug output.
        let Self {
            descriptor,
            class,
            handler: _,
        } = self;

        let mut out = f.debug_struct("ToolEntry");
        out.field("descriptor", descriptor);
        out.field("class", class);
        out.field("handler", &"<dyn Fn>");
        out.finish()
    }
}
/// Name-indexed collection of [`ToolEntry`] values.
#[derive(Debug, Default, Clone)]
pub struct ToolRegistry {
/// Entries keyed by the descriptor name they were inserted under.
entries: HashMap<String, ToolEntry>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the registry that matches the configuration held by `ctx`.
    ///
    /// Public tools are always registered. Account tools are added when
    /// `ctx.has_credentials()` is true, and trading tools are added only when
    /// credentials are present *and* `ctx.config.allow_trading` is set.
    #[must_use]
    pub fn build(ctx: &AdapterContext) -> Self {
        let mut registry = Self::new();
        public::register(&mut registry);
        // Credentials gate both authenticated classes, so check them once and
        // nest the trading gate underneath.
        if ctx.has_credentials() {
            account::register(&mut registry);
            if ctx.config.allow_trading {
                trading::register(&mut registry);
            }
        }
        registry
    }

    /// Inserts `entry` under its descriptor name.
    ///
    /// Returns the entry previously stored under the same name, if any, so
    /// callers can detect accidental duplicate registrations.
    // NOTE(review): the reason for this `allow` is not visible in this file;
    // presumably some build configurations have no non-test caller. Confirm
    // it is still needed.
    #[allow(dead_code)]
    pub(crate) fn insert(&mut self, entry: ToolEntry) -> Option<ToolEntry> {
        let name = entry.descriptor.name.to_string();
        self.entries.insert(name, entry)
    }

    /// Returns a clone of every registered descriptor, sorted by tool name.
    ///
    /// Sorting makes the output independent of `HashMap` iteration order.
    #[must_use]
    pub fn list(&self) -> Vec<Tool> {
        let mut tools: Vec<Tool> = self
            .entries
            .values()
            .map(|entry| entry.descriptor.clone())
            .collect();
        // Entries inserted through `insert` have distinct names, so stability
        // buys nothing here; the unstable sort avoids the temporary buffer.
        tools.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        tools
    }

    /// Number of registered tools.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no tool is registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Looks up the entry registered under `name`.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&ToolEntry> {
        self.entries.get(name)
    }

    /// Returns `true` when a tool is registered under `name`.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        self.entries.contains_key(name)
    }

    /// Dispatches `input` to the tool registered under `name`.
    ///
    /// The entry's class is re-checked against `ctx` on every call, even
    /// though [`ToolRegistry::build`] only registers enabled classes, so a
    /// registry populated by other means is still gated.
    ///
    /// # Errors
    ///
    /// * A validation error on field `"name"` when no tool is registered
    ///   under `name`.
    /// * `AdapterError::NotEnabled` when the tool's class is not enabled for
    ///   `ctx`.
    /// * Whatever error the tool handler itself returns.
    pub async fn call(
        &self,
        ctx: &AdapterContext,
        name: &str,
        input: Value,
    ) -> Result<Value, AdapterError> {
        let entry = self
            .get(name)
            .ok_or_else(|| AdapterError::validation("name", format!("unknown tool: `{name}`")))?;
        check_class_enabled(entry.class, ctx, &entry.descriptor.name)?;
        (entry.handler)(ctx, input).await
    }
}
/// Verifies that tools of `class` may be dispatched under `ctx`.
///
/// * [`ToolClass::Read`] is always allowed.
/// * [`ToolClass::Account`] requires `ctx.has_credentials()`.
/// * [`ToolClass::Trading`] requires credentials and
///   `ctx.config.allow_trading`.
///
/// # Errors
///
/// Returns `AdapterError::NotEnabled` carrying the tool `name` and a label
/// describing exactly which settings are missing.
// NOTE(review): `#[inline(never)]` is preserved from the original; its
// motivation is not visible in this file.
#[inline(never)]
fn check_class_enabled(
    class: ToolClass,
    ctx: &AdapterContext,
    name: &str,
) -> Result<(), AdapterError> {
    // Every arm either returns `Ok` or yields the label of what is missing,
    // so the error is built in exactly one place and no `unreachable!` arm
    // is needed to satisfy exhaustiveness.
    let missing: &'static str = match class {
        ToolClass::Read => return Ok(()),
        ToolClass::Account => {
            if ctx.has_credentials() {
                return Ok(());
            }
            ToolClass::Account.flag()
        }
        ToolClass::Trading => match (ctx.has_credentials(), ctx.config.allow_trading) {
            (true, true) => return Ok(()),
            (false, false) => "DERIBIT_CLIENT_ID + DERIBIT_CLIENT_SECRET + --allow-trading",
            (false, true) => ToolClass::Account.flag(),
            (true, false) => ToolClass::Trading.flag(),
        },
    };

    Err(AdapterError::NotEnabled {
        tool: name.to_string(),
        flag: missing.to_string(),
    })
}
// Unit tests for the registry: class flag strings, listing order,
// build-time gating, and dispatch-time gating for each tool class.
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{Config, LogFormat, OrderTransport, Transport};
use rmcp::model::Tool;
use serde_json::json;
use std::net::SocketAddr;
use std::sync::Arc;
// Builds a `Config` for tests. `with_creds` populates both the client id and
// the client secret; `allow_trading` is passed through unchanged.
fn cfg(with_creds: bool, allow_trading: bool) -> Config {
Config {
endpoint: "https://test.deribit.com".to_string(),
client_id: with_creds.then(|| "id".to_string()),
client_secret: with_creds.then(|| "secret".to_string()),
allow_trading,
max_order_usd: None,
transport: Transport::Stdio,
http_listen: SocketAddr::from(([127, 0, 0, 1], 8723)),
http_bearer_token: None,
log_format: LogFormat::Text,
order_transport: OrderTransport::Http,
}
}
// Wraps `cfg(..)` in an `AdapterContext`, panicking if construction fails.
fn ctx(with_creds: bool, allow_trading: bool) -> AdapterContext {
AdapterContext::new(Arc::new(cfg(with_creds, allow_trading))).expect("ctx")
}
// Empty JSON object map used as the input schema of test descriptors.
fn empty_schema() -> Arc<serde_json::Map<String, Value>> {
Arc::new(serde_json::Map::new())
}
// Builds a `ToolEntry` named `name` in `class` whose handler ignores its
// arguments and always resolves to `{"ok": true}`.
fn make_entry(name: &'static str, class: ToolClass) -> ToolEntry {
let descriptor = Tool::new(
std::borrow::Cow::Borrowed(name),
"test tool",
empty_schema(),
);
let handler: ToolHandlerFn =
Arc::new(|_ctx, _input| Box::pin(async move { Ok(json!({"ok": true})) }));
ToolEntry {
descriptor,
class,
handler,
}
}
// Pins the exact label returned by `ToolClass::flag` for every variant.
#[test]
fn class_flags_match_documentation() {
assert_eq!(ToolClass::Read.flag(), "(always enabled)");
assert_eq!(
ToolClass::Account.flag(),
"DERIBIT_CLIENT_ID + DERIBIT_CLIENT_SECRET"
);
assert_eq!(ToolClass::Trading.flag(), "--allow-trading");
}
// A freshly created registry reports no entries through any accessor.
#[test]
fn registry_starts_empty() {
let r = ToolRegistry::new();
assert!(r.is_empty());
assert_eq!(r.len(), 0);
assert!(r.list().is_empty());
}
// `list` returns descriptors ordered by name regardless of insertion order.
#[test]
fn registry_lists_sorted_by_name() {
let mut r = ToolRegistry::new();
r.insert(make_entry("get_ticker", ToolClass::Read));
r.insert(make_entry("get_book", ToolClass::Read));
let listed = r.list();
let names: Vec<&str> = listed.iter().map(|t| t.name.as_ref()).collect();
assert_eq!(names, vec!["get_book", "get_ticker"]);
}
// Without credentials only `public::register` runs, so every entry must be
// `Read`. NOTE(review): the literal 14 is the expected count of public tools
// and must be updated whenever that set changes.
#[test]
fn build_without_creds_includes_only_read() {
let registry = ToolRegistry::build(&ctx(false, false));
assert_eq!(registry.len(), 14);
for tool in registry.list() {
let entry = registry.get(tool.name.as_ref()).expect("entry");
assert_eq!(entry.class, ToolClass::Read, "{}", tool.name);
}
}
// Calling an unregistered name yields a validation error on field "name".
#[tokio::test]
async fn dispatch_unknown_tool_returns_validation() {
let registry = ToolRegistry::new();
let ctx = ctx(false, false);
let err = registry
.call(&ctx, "no_such_tool", Value::Null)
.await
.unwrap_err();
match err {
AdapterError::Validation { field, .. } => assert_eq!(field, "name"),
other => panic!("unexpected: {other:?}"),
}
}
// `Read` tools dispatch even when no credentials are configured.
#[tokio::test]
async fn dispatch_read_class_succeeds_without_creds() {
let mut registry = ToolRegistry::new();
registry.insert(make_entry("ping", ToolClass::Read));
let ctx = ctx(false, false);
let out = registry.call(&ctx, "ping", Value::Null).await.expect("ok");
assert_eq!(out, json!({"ok": true}));
}
// `Account` tools are rejected with `NotEnabled` (carrying the account flag
// label) when the context has no credentials.
#[tokio::test]
async fn dispatch_account_class_requires_credentials() {
let mut registry = ToolRegistry::new();
registry.insert(make_entry("get_account_summary", ToolClass::Account));
let ctx = ctx(false, false);
let err = registry
.call(&ctx, "get_account_summary", Value::Null)
.await
.unwrap_err();
match err {
AdapterError::NotEnabled { tool, flag } => {
assert_eq!(tool, "get_account_summary");
assert_eq!(flag, ToolClass::Account.flag());
}
other => panic!("unexpected: {other:?}"),
}
}
// `Account` tools dispatch once credentials are present.
#[tokio::test]
async fn dispatch_account_class_succeeds_with_credentials() {
let mut registry = ToolRegistry::new();
registry.insert(make_entry("get_account_summary", ToolClass::Account));
let ctx = ctx(true, false);
registry
.call(&ctx, "get_account_summary", Value::Null)
.await
.expect("ok");
}
// With credentials but without `allow_trading`, `Trading` tools are rejected
// and the error names only the missing `--allow-trading` flag.
#[tokio::test]
async fn dispatch_trading_class_requires_allow_trading_flag() {
let mut registry = ToolRegistry::new();
registry.insert(make_entry("place_order", ToolClass::Trading));
let ctx = ctx(true, false);
let err = registry
.call(&ctx, "place_order", Value::Null)
.await
.unwrap_err();
match err {
AdapterError::NotEnabled { tool, flag } => {
assert_eq!(tool, "place_order");
assert_eq!(flag, "--allow-trading");
}
other => panic!("unexpected: {other:?}"),
}
}
// `Trading` tools dispatch when both credentials and `allow_trading` are set.
#[tokio::test]
async fn dispatch_trading_class_succeeds_with_creds_and_flag() {
let mut registry = ToolRegistry::new();
registry.insert(make_entry("place_order", ToolClass::Trading));
let ctx = ctx(true, true);
registry
.call(&ctx, "place_order", Value::Null)
.await
.expect("ok");
}
}