use std::sync::Arc;
use rmcp::model::Tool;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::schema::{parse_input as parse, schema_for};
use super::{ToolClass, ToolEntry, ToolHandlerFn, ToolRegistry};
use crate::context::AdapterContext;
use crate::error::AdapterError;
/// Registers every account-class tool defined in this module with `registry`.
///
/// Entries are inserted in a fixed order: account summary, positions,
/// subaccounts, transaction log, deposits, withdrawals, open orders (by
/// currency, then by instrument) and user trades (by currency, then by
/// instrument).
pub fn register(registry: &mut ToolRegistry) {
    let entries = [
        get_account_summary_tool(),
        get_positions_tool(),
        get_subaccounts_tool(),
        get_transaction_log_tool(),
        get_deposits_tool(),
        get_withdrawals_tool(),
        get_open_orders_by_currency_tool(),
        get_open_orders_by_instrument_tool(),
        get_user_trades_by_currency_tool(),
        get_user_trades_by_instrument_tool(),
    ];
    for entry in entries {
        registry.insert(entry);
    }
}
// Arguments accepted by the `get_account_summary` tool.
// (`//` rather than `///` on purpose: schemars copies doc comments into the
// generated tool schema, so `///` would change the published schema.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct GetAccountSummaryInput {
    // Required; forwarded by reference to `get_account_summary`.
    pub currency: String,
    // Optional flag, forwarded unchanged; omitted from serialized output when `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extended: Option<bool>,
}
/// Builds the `get_account_summary` tool entry.
///
/// The input schema is derived from [`GetAccountSummaryInput`]. The entry is
/// classed as [`ToolClass::Account`], and calls are dispatched to
/// [`handle_get_account_summary`].
fn get_account_summary_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn =
        Arc::new(|ctx, input| Box::pin(handle_get_account_summary(ctx, input)));
    let descriptor = Tool::new(
        "get_account_summary",
        "Account balance / equity / margin for a single currency.",
        schema_for::<GetAccountSummaryInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_account_summary(
ctx: &AdapterContext,
input: Value,
) -> Result<Value, AdapterError> {
let input: GetAccountSummaryInput = parse(input)?;
let result = ctx
.http
.get_account_summary(&input.currency, input.extended)
.await?;
Ok(serde_json::to_value(&result)?)
}
// Arguments accepted by the `get_positions` tool. Every field is an optional
// filter, so `{}` is a valid argument object.
// (`//` rather than `///` on purpose: schemars copies doc comments into the
// generated tool schema, so `///` would change the published schema.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct GetPositionsInput {
    // Optional currency filter, forwarded as `Option<&str>`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub currency: Option<String>,
    // Optional instrument-kind filter, forwarded as `Option<&str>`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kind: Option<String>,
    // Optional subaccount filter; `i32` to match the `get_positions` client call.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub subaccount_id: Option<i32>,
}
/// Builds the `get_positions` tool entry.
///
/// The input schema is derived from [`GetPositionsInput`]. The entry is classed
/// as [`ToolClass::Account`], and calls are dispatched to [`handle_get_positions`].
fn get_positions_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn = Arc::new(|ctx, input| Box::pin(handle_get_positions(ctx, input)));
    let descriptor = Tool::new(
        "get_positions",
        "Open positions, optionally filtered by currency / kind / subaccount.",
        schema_for::<GetPositionsInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_positions(ctx: &AdapterContext, input: Value) -> Result<Value, AdapterError> {
let input: GetPositionsInput = parse(input)?;
let result = ctx
.http
.get_positions(
input.currency.as_deref(),
input.kind.as_deref(),
input.subaccount_id,
)
.await?;
Ok(serde_json::to_value(&result)?)
}
// Arguments accepted by the `get_subaccounts` tool; `{}` is valid.
// (`//` rather than `///` on purpose: schemars copies doc comments into the
// generated tool schema, so `///` would change the published schema.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct GetSubaccountsInput {
    // Optional flag forwarded unchanged to `get_subaccounts`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub with_portfolio: Option<bool>,
}
/// Builds the `get_subaccounts` tool entry.
///
/// The input schema is derived from [`GetSubaccountsInput`]. The entry is
/// classed as [`ToolClass::Account`], and calls are dispatched to
/// [`handle_get_subaccounts`].
fn get_subaccounts_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn =
        Arc::new(|ctx, input| Box::pin(handle_get_subaccounts(ctx, input)));
    let descriptor = Tool::new(
        "get_subaccounts",
        "List subaccounts, optionally including portfolio per subaccount.",
        schema_for::<GetSubaccountsInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_subaccounts(ctx: &AdapterContext, input: Value) -> Result<Value, AdapterError> {
let input: GetSubaccountsInput = parse(input)?;
let result = ctx.http.get_subaccounts(input.with_portfolio).await?;
Ok(serde_json::to_value(&result)?)
}
// Arguments accepted by the `get_transaction_log` tool. The currency and the
// time window are required; every other field is optional and is copied
// verbatim into `TransactionLogRequest`.
// (`//` rather than `///` on purpose: schemars copies doc comments into the
// generated tool schema, so `///` would change the published schema.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct GetTransactionLogInput {
    pub currency: String,
    // Window bounds. NOTE(review): the unit is whatever `TransactionLogRequest`
    // expects (presumably epoch milliseconds); it is not visible from here.
    pub start_timestamp: u64,
    pub end_timestamp: u64,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub query: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub count: Option<u64>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub subaccount_id: Option<u64>,
    // Pagination cursor, forwarded unchanged.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continuation: Option<u64>,
}
/// Builds the `get_transaction_log` tool entry.
///
/// The input schema is derived from [`GetTransactionLogInput`]. The entry is
/// classed as [`ToolClass::Account`], and calls are dispatched to
/// [`handle_get_transaction_log`].
fn get_transaction_log_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn =
        Arc::new(|ctx, input| Box::pin(handle_get_transaction_log(ctx, input)));
    let descriptor = Tool::new(
        "get_transaction_log",
        "Account transaction log for a currency over a window, with optional pagination.",
        schema_for::<GetTransactionLogInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_transaction_log(
ctx: &AdapterContext,
input: Value,
) -> Result<Value, AdapterError> {
let input: GetTransactionLogInput = parse(input)?;
let request = deribit_http::model::transaction::TransactionLogRequest {
currency: input.currency,
start_timestamp: input.start_timestamp,
end_timestamp: input.end_timestamp,
query: input.query,
count: input.count,
subaccount_id: input.subaccount_id,
continuation: input.continuation,
};
let result = ctx.http.get_transaction_log(request).await?;
Ok(serde_json::to_value(&result)?)
}
// Shared argument shape for the `get_deposits` and `get_withdrawals` tools:
// a required currency plus optional `count` / `offset` pagination, all
// forwarded unchanged to the HTTP client.
// (`//` rather than `///` on purpose: schemars copies doc comments into the
// generated tool schema, so `///` would change the published schema.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct PaginatedCurrencyInput {
    pub currency: String,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub count: Option<u32>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub offset: Option<u32>,
}
/// Builds the `get_deposits` tool entry.
///
/// The input schema is derived from [`PaginatedCurrencyInput`]. The entry is
/// classed as [`ToolClass::Account`], and calls are dispatched to
/// [`handle_get_deposits`].
fn get_deposits_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn = Arc::new(|ctx, input| Box::pin(handle_get_deposits(ctx, input)));
    let descriptor = Tool::new(
        "get_deposits",
        "Recent deposits for a currency, paginated.",
        schema_for::<PaginatedCurrencyInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_deposits(ctx: &AdapterContext, input: Value) -> Result<Value, AdapterError> {
let input: PaginatedCurrencyInput = parse(input)?;
let result = ctx
.http
.get_deposits(&input.currency, input.count, input.offset)
.await?;
Ok(serde_json::to_value(&result)?)
}
/// Builds the `get_withdrawals` tool entry.
///
/// The input schema is derived from [`PaginatedCurrencyInput`]. The entry is
/// classed as [`ToolClass::Account`], and calls are dispatched to
/// [`handle_get_withdrawals`].
fn get_withdrawals_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn =
        Arc::new(|ctx, input| Box::pin(handle_get_withdrawals(ctx, input)));
    let descriptor = Tool::new(
        "get_withdrawals",
        "Recent withdrawals for a currency, paginated.",
        schema_for::<PaginatedCurrencyInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_withdrawals(ctx: &AdapterContext, input: Value) -> Result<Value, AdapterError> {
let input: PaginatedCurrencyInput = parse(input)?;
let result = ctx
.http
.get_withdrawals(&input.currency, input.count, input.offset)
.await?;
Ok(serde_json::to_value(&result)?)
}
// Arguments accepted by the `get_open_orders_by_currency` tool.
// (`//` rather than `///` on purpose: schemars copies doc comments into the
// generated tool schema, so `///` would change the published schema.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct GetOpenOrdersByCurrencyInput {
    pub currency: String,
    // Optional instrument-kind filter, forwarded as `Option<&str>`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kind: Option<String>,
    // Exposed on the wire as `type`; named `order_type` in Rust because `type`
    // is a keyword.
    #[serde(default)]
    #[serde(rename = "type")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order_type: Option<String>,
}
/// Builds the `get_open_orders_by_currency` tool entry.
///
/// The input schema is derived from [`GetOpenOrdersByCurrencyInput`]. The entry
/// is classed as [`ToolClass::Account`], and calls are dispatched to
/// [`handle_get_open_orders_by_currency`].
fn get_open_orders_by_currency_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn =
        Arc::new(|ctx, input| Box::pin(handle_get_open_orders_by_currency(ctx, input)));
    let descriptor = Tool::new(
        "get_open_orders_by_currency",
        "Open orders for a currency, optionally filtered by kind and type.",
        schema_for::<GetOpenOrdersByCurrencyInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_open_orders_by_currency(
ctx: &AdapterContext,
input: Value,
) -> Result<Value, AdapterError> {
let input: GetOpenOrdersByCurrencyInput = parse(input)?;
let result = ctx
.http
.get_open_orders_by_currency(
&input.currency,
input.kind.as_deref(),
input.order_type.as_deref(),
)
.await?;
Ok(serde_json::to_value(&result)?)
}
// Arguments accepted by the `get_open_orders_by_instrument` tool.
// (`//` rather than `///` on purpose: schemars copies doc comments into the
// generated tool schema, so `///` would change the published schema.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct GetOpenOrdersByInstrumentInput {
    pub instrument_name: String,
    // Exposed on the wire as `type`; named `order_type` in Rust because `type`
    // is a keyword.
    #[serde(default)]
    #[serde(rename = "type")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order_type: Option<String>,
}
/// Builds the `get_open_orders_by_instrument` tool entry.
///
/// The input schema is derived from [`GetOpenOrdersByInstrumentInput`]. The
/// entry is classed as [`ToolClass::Account`], and calls are dispatched to
/// [`handle_get_open_orders_by_instrument`].
fn get_open_orders_by_instrument_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn =
        Arc::new(|ctx, input| Box::pin(handle_get_open_orders_by_instrument(ctx, input)));
    let descriptor = Tool::new(
        "get_open_orders_by_instrument",
        "Open orders for a single instrument, optionally filtered by type.",
        schema_for::<GetOpenOrdersByInstrumentInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_open_orders_by_instrument(
ctx: &AdapterContext,
input: Value,
) -> Result<Value, AdapterError> {
let input: GetOpenOrdersByInstrumentInput = parse(input)?;
let result = ctx
.http
.get_open_orders_by_instrument(&input.instrument_name, input.order_type.as_deref())
.await?;
Ok(serde_json::to_value(&result)?)
}
// Arguments accepted by the `get_user_trades_by_currency` tool.
// `currency`, `kind` and `sorting` arrive as strings; the handler converts
// them into the client's typed enums before building the request.
// (`//` rather than `///` on purpose: schemars copies doc comments into the
// generated tool schema, so `///` would change the published schema.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct GetUserTradesByCurrencyInput {
    pub currency: String,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kind: Option<String>,
    // Trade-id window bounds.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_id: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_id: Option<String>,
    // The handler checks this against `validate_count_range`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub count: Option<u32>,
    // Timestamp window bounds.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_timestamp: Option<u64>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_timestamp: Option<u64>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sorting: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub historical: Option<bool>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub subaccount_id: Option<u32>,
}
/// Builds the `get_user_trades_by_currency` tool entry.
///
/// The input schema is derived from [`GetUserTradesByCurrencyInput`]. The entry
/// is classed as [`ToolClass::Account`], and calls are dispatched to
/// [`handle_get_user_trades_by_currency`].
fn get_user_trades_by_currency_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn =
        Arc::new(|ctx, input| Box::pin(handle_get_user_trades_by_currency(ctx, input)));
    let descriptor = Tool::new(
        "get_user_trades_by_currency",
        "User trades for a currency over an id / timestamp window with sort + historical opt-in.",
        schema_for::<GetUserTradesByCurrencyInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_user_trades_by_currency(
ctx: &AdapterContext,
input: Value,
) -> Result<Value, AdapterError> {
let input: GetUserTradesByCurrencyInput = parse(input)?;
if let Some(count) = input.count {
validate_count_range(count)?;
}
let currency = parse_currency(&input.currency)?;
let kind = input.kind.as_deref().map(parse_kind).transpose()?;
let sorting = input.sorting.as_deref().map(parse_sorting).transpose()?;
let request = deribit_http::model::request::trade::TradesRequest {
currency,
kind,
start_id: input.start_id,
end_id: input.end_id,
count: input.count,
start_timestamp: input.start_timestamp,
end_timestamp: input.end_timestamp,
sorting,
historical: input.historical,
subaccount_id: input.subaccount_id,
};
let result = ctx.http.get_user_trades_by_currency(request).await?;
Ok(serde_json::to_value(&result)?)
}
fn parse_currency(s: &str) -> Result<deribit_http::model::Currency, AdapterError> {
serde_json::from_value(serde_json::Value::String(s.to_uppercase())).map_err(|err| {
AdapterError::Validation {
field: "currency".to_string(),
message: err.to_string(),
}
})
}
fn parse_kind(s: &str) -> Result<deribit_http::model::InstrumentKind, AdapterError> {
serde_json::from_value(serde_json::Value::String(s.to_lowercase())).map_err(|err| {
AdapterError::Validation {
field: "kind".to_string(),
message: err.to_string(),
}
})
}
fn parse_sorting(s: &str) -> Result<deribit_http::model::SortDirection, AdapterError> {
serde_json::from_value(serde_json::Value::String(s.to_lowercase())).map_err(|err| {
AdapterError::Validation {
field: "sorting".to_string(),
message: err.to_string(),
}
})
}
/// Checks that a requested trade `count` lies in the inclusive range `1..=1000`.
///
/// # Errors
///
/// Returns `AdapterError::Validation` with `field = "count"` for `0` or for any
/// value above `1000`.
fn validate_count_range(count: u32) -> Result<(), AdapterError> {
    match count {
        1..=1000 => Ok(()),
        _ => Err(AdapterError::Validation {
            field: "count".to_string(),
            message: format!("expected 1..=1000, got {count}"),
        }),
    }
}
// Arguments accepted by the `get_user_trades_by_instrument` tool. Only the
// instrument name is required.
// (`//` rather than `///` on purpose: schemars copies doc comments into the
// generated tool schema, so `///` would change the published schema.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct GetUserTradesByInstrumentInput {
    pub instrument_name: String,
    // Sequence-number window bounds.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_seq: Option<u64>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_seq: Option<u64>,
    // The handler checks this against `validate_count_range`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub count: Option<u32>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_old: Option<bool>,
    // Optional sort-order string.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sorting: Option<String>,
}
/// Builds the `get_user_trades_by_instrument` tool entry.
///
/// The input schema is derived from [`GetUserTradesByInstrumentInput`]. The
/// entry is classed as [`ToolClass::Account`], and calls are dispatched to
/// [`handle_get_user_trades_by_instrument`].
fn get_user_trades_by_instrument_tool() -> ToolEntry {
    // Pin-box the handler's future so the closure fits `ToolHandlerFn`.
    let handler: ToolHandlerFn =
        Arc::new(|ctx, input| Box::pin(handle_get_user_trades_by_instrument(ctx, input)));
    let descriptor = Tool::new(
        "get_user_trades_by_instrument",
        "User trades for a single instrument over a sequence-number window.",
        schema_for::<GetUserTradesByInstrumentInput>(),
    );
    ToolEntry {
        class: ToolClass::Account,
        descriptor,
        handler,
    }
}
async fn handle_get_user_trades_by_instrument(
ctx: &AdapterContext,
input: Value,
) -> Result<Value, AdapterError> {
let input: GetUserTradesByInstrumentInput = parse(input)?;
if let Some(count) = input.count {
validate_count_range(count)?;
}
let result = ctx
.http
.get_user_trades_by_instrument(
&input.instrument_name,
input.start_seq,
input.end_seq,
input.count,
input.include_old,
input.sorting.as_deref(),
)
.await?;
Ok(serde_json::to_value(&result)?)
}
// Unit tests for the account tools. They cover registration, tool
// classification, input-struct (de)serialization, and the local validation
// helpers. The async handlers are not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn all_account_tools_register_under_account_class() {
        for entry in [
            get_account_summary_tool(),
            get_positions_tool(),
            get_subaccounts_tool(),
            get_transaction_log_tool(),
            get_deposits_tool(),
            get_withdrawals_tool(),
            get_open_orders_by_currency_tool(),
            get_open_orders_by_instrument_tool(),
            get_user_trades_by_currency_tool(),
            get_user_trades_by_instrument_tool(),
        ] {
            assert_eq!(entry.class, ToolClass::Account);
        }
    }

    #[test]
    fn register_populates_full_account_set() {
        let mut registry = ToolRegistry::new();
        register(&mut registry);
        let listed = registry.list();
        let names: Vec<&str> = listed.iter().map(|t| t.name.as_ref()).collect();
        for expected in [
            "get_account_summary",
            "get_deposits",
            "get_open_orders_by_currency",
            "get_open_orders_by_instrument",
            "get_positions",
            "get_subaccounts",
            "get_transaction_log",
            "get_user_trades_by_currency",
            "get_user_trades_by_instrument",
            "get_withdrawals",
        ] {
            assert!(
                names.contains(&expected),
                "missing tool {expected}; got {names:?}"
            );
        }
        assert_eq!(registry.len(), 10);
    }

    #[test]
    fn open_orders_by_currency_input_renames_type_field() {
        let v = serde_json::json!({"currency": "BTC", "type": "limit"});
        let parsed: GetOpenOrdersByCurrencyInput = serde_json::from_value(v).expect("parse");
        assert_eq!(parsed.order_type.as_deref(), Some("limit"));
    }

    #[test]
    fn open_orders_by_instrument_input_renames_type_field() {
        let v = serde_json::json!({"instrument_name": "BTC-PERPETUAL", "type": "limit"});
        let parsed: GetOpenOrdersByInstrumentInput = serde_json::from_value(v).expect("parse");
        assert_eq!(parsed.order_type.as_deref(), Some("limit"));
    }

    #[test]
    fn open_orders_input_serializes_type_and_skips_none() {
        let input = GetOpenOrdersByCurrencyInput {
            currency: "BTC".to_string(),
            kind: None,
            order_type: Some("limit".to_string()),
        };
        let v = serde_json::to_value(&input).expect("serialize");
        assert_eq!(v, serde_json::json!({"currency": "BTC", "type": "limit"}));
    }

    #[test]
    fn user_trades_by_instrument_input_accepts_required_only() {
        let v = serde_json::json!({"instrument_name": "BTC-PERPETUAL"});
        let parsed: GetUserTradesByInstrumentInput = serde_json::from_value(v).expect("parse");
        assert!(parsed.start_seq.is_none());
        assert!(parsed.end_seq.is_none());
        assert!(parsed.count.is_none());
        assert!(parsed.include_old.is_none());
        assert!(parsed.sorting.is_none());
    }

    #[test]
    fn user_trades_by_currency_input_accepts_required_only() {
        let v = serde_json::json!({"currency": "BTC"});
        let parsed: GetUserTradesByCurrencyInput = serde_json::from_value(v).expect("parse");
        assert!(parsed.kind.is_none());
        assert!(parsed.count.is_none());
        assert!(parsed.start_timestamp.is_none());
        assert!(parsed.end_timestamp.is_none());
        assert!(parsed.sorting.is_none());
        assert!(parsed.subaccount_id.is_none());
    }

    #[test]
    fn transaction_log_input_requires_window() {
        let err =
            parse::<GetTransactionLogInput>(serde_json::json!({"currency": "BTC"})).unwrap_err();
        assert!(matches!(err, AdapterError::Validation { .. }));
    }

    #[test]
    fn transaction_log_input_accepts_full_window() {
        let parsed = parse::<GetTransactionLogInput>(serde_json::json!({
            "currency": "BTC",
            "start_timestamp": 1,
            "end_timestamp": 2
        }))
        .expect("parse");
        assert_eq!(parsed.start_timestamp, 1);
        assert_eq!(parsed.end_timestamp, 2);
        assert!(parsed.continuation.is_none());
    }

    #[test]
    fn paginated_input_accepts_required_only() {
        let parsed: PaginatedCurrencyInput =
            serde_json::from_value(serde_json::json!({"currency": "BTC"})).expect("parse");
        assert!(parsed.count.is_none());
        assert!(parsed.offset.is_none());
    }

    #[test]
    fn account_summary_input_requires_currency() {
        let err = parse::<GetAccountSummaryInput>(serde_json::json!({})).unwrap_err();
        match err {
            AdapterError::Validation { field, .. } => assert_eq!(field, "arguments"),
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn positions_input_accepts_no_filters() {
        let parsed: GetPositionsInput =
            serde_json::from_value(serde_json::json!({})).expect("parse");
        assert!(parsed.currency.is_none());
        assert!(parsed.kind.is_none());
        assert!(parsed.subaccount_id.is_none());
    }

    #[test]
    fn subaccounts_input_accepts_no_arguments() {
        let parsed: GetSubaccountsInput =
            serde_json::from_value(serde_json::json!({})).expect("parse");
        assert!(parsed.with_portfolio.is_none());
    }

    #[test]
    fn parse_currency_rejects_out_of_vocab() {
        let err = parse_currency("DOGE").unwrap_err();
        match err {
            AdapterError::Validation { field, .. } => assert_eq!(field, "currency"),
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn parse_kind_rejects_out_of_vocab() {
        let err = parse_kind("perpetual").unwrap_err();
        match err {
            AdapterError::Validation { field, .. } => assert_eq!(field, "kind"),
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn parse_sorting_rejects_out_of_vocab() {
        let err = parse_sorting("random").unwrap_err();
        match err {
            AdapterError::Validation { field, .. } => assert_eq!(field, "sorting"),
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn parse_currency_accepts_lowercase() {
        let parsed = parse_currency("btc").expect("ok");
        assert!(matches!(parsed, deribit_http::model::Currency::Btc));
    }

    #[test]
    fn validate_count_rejects_zero_and_over_1000() {
        assert!(matches!(
            validate_count_range(0).unwrap_err(),
            AdapterError::Validation { ref field, .. } if field == "count"
        ));
        assert!(matches!(
            validate_count_range(1001).unwrap_err(),
            AdapterError::Validation { ref field, .. } if field == "count"
        ));
    }

    #[test]
    fn validate_count_accepts_boundary_values() {
        assert!(validate_count_range(1).is_ok());
        assert!(validate_count_range(1000).is_ok());
    }
}