// ares/src/mcp/usage.rs
// Records every MCP tool call as a usage event.
// This feeds into the same usage/billing system as HTTP API calls.

5use crate::types::AppError;
6use chrono::{Datelike, Utc};
7use uuid::Uuid;
/// The type of MCP operation being tracked.
///
/// Each variant maps to a stable database name via [`McpOperation::as_str`]
/// and to a minimum billable cost via [`McpOperation::token_weight`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum McpOperation {
    ListAgents,
    RunAgent,
    GetStatus,
    DeployAgent,
    GetUsage,
    ErukaRead,
    ErukaWrite,
    ErukaSearch,
}

impl McpOperation {
    /// Returns the operation name as stored in the database
    /// (every name is prefixed with `mcp.`).
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::ListAgents => "mcp.ares_list_agents",
            Self::RunAgent => "mcp.ares_run_agent",
            Self::GetStatus => "mcp.ares_get_status",
            Self::DeployAgent => "mcp.ares_deploy_agent",
            Self::GetUsage => "mcp.ares_get_usage",
            Self::ErukaRead => "mcp.eruka_read",
            Self::ErukaWrite => "mcp.eruka_write",
            Self::ErukaSearch => "mcp.eruka_search",
        }
    }

    /// Returns the token cost weight for this operation.
    ///
    /// Used for usage quota calculations as the minimum cost charged per call:
    /// - Simple reads (`ListAgents`, `GetStatus`, `GetUsage`, `ErukaRead`): 1 unit
    /// - Writes and searches (`ErukaWrite`, `ErukaSearch`): 2 units
    /// - Deploy (`DeployAgent`): 5 units
    /// - Agent runs (`RunAgent`): 10 units (LLM call involved)
    pub const fn token_weight(&self) -> u64 {
        match self {
            Self::ListAgents => 1,
            Self::RunAgent => 10,
            Self::GetStatus => 1,
            Self::DeployAgent => 5,
            Self::GetUsage => 1,
            Self::ErukaRead => 1,
            Self::ErukaWrite => 2,
            Self::ErukaSearch => 2,
        }
    }
}

57/// Records a single MCP usage event in the database.
58///
59/// # Arguments
60/// - `pool`: PostgreSQL connection pool
61/// - `tenant_id`: The tenant making the call
62/// - `operation`: Which MCP tool was called
63/// - `tokens_used`: Actual tokens consumed (0 for non-LLM calls, actual count for RunAgent)
64/// - `success`: Whether the call succeeded
65/// - `duration_ms`: How long the call took in milliseconds
66///
67/// # Errors
68/// Returns error if the database insert fails. The caller should
69/// log the error but NOT fail the tool call — usage tracking failure
70/// should not block the user's request.
71pub async fn record_mcp_usage(
72    pool: &sqlx::PgPool,
73    tenant_id: &str,
74    operation: McpOperation,
75    tokens_used: u64,
76    success: bool,
77    duration_ms: u64,
78) -> Result<(), AppError> {
79    let now_ts = Utc::now().timestamp();
80    let op_name = operation.as_str();
81    let weight = operation.token_weight();
82
83    // The effective_tokens is the larger of actual tokens and the weight minimum.
84    // This ensures that even non-LLM calls have a baseline cost.
85    let effective_tokens = std::cmp::max(tokens_used, weight);
86
87    // Insert into unified usage_events table (matches migrations/001_usage_events_unified.sql)
88    let result = sqlx::query(
89        r#"
90        INSERT INTO usage_events (
91            id, tenant_id, source, request_count, token_count,
92            operation, tokens_used, effective_tokens, success, duration_ms, created_at
93        )
94        VALUES ($1, $2, 'mcp', 1, $3, $4, $5, $6, $7, $8, $9)
95        "#,
96    )
97    .bind(Uuid::new_v4().to_string())
98    .bind(tenant_id)
99    .bind(effective_tokens as i64) // token_count = effective_tokens for quota tracking
100    .bind(op_name)
101    .bind(tokens_used as i64)
102    .bind(effective_tokens as i64)
103    .bind(success)
104    .bind(duration_ms as i64)
105    .bind(now_ts)
106    .execute(pool)
107    .await;
108
109    match result {
110        Ok(_) => {
111            tracing::debug!(
112                tenant_id = tenant_id,
113                operation = op_name,
114                tokens = effective_tokens,
115                success = success,
116                duration_ms = duration_ms,
117                "MCP usage event recorded"
118            );
119            Ok(())
120        }
121        Err(e) => {
122            tracing::error!(
123                error = %e,
124                tenant_id = tenant_id,
125                operation = op_name,
126                "Failed to record MCP usage event - continuing anyway"
127            );
128            // Don't fail the tool call - just log the error
129            Ok(())
130        }
131    }
132}
133
134/// Checks if the tenant has exceeded their usage quota.
135///
136/// # Returns
137/// - `Ok(true)` if the tenant is within their quota
138/// - `Ok(false)` if the tenant has exceeded their quota
139/// - `Err` if the database query fails
140pub async fn check_quota(
141    pool: &sqlx::PgPool,
142    tenant_id: &str,
143    tier: &str,
144) -> Result<bool, AppError> {
145    // Get the monthly quota for this tier
146    let max_tokens: i64 = match tier {
147        "free" => 10_000,
148        "dev" => 500_000,
149        "pro" => 5_000_000,
150        "enterprise" => i64::MAX, // unlimited for enterprise
151        _ => 10_000,              // default to free tier
152    };
153
154    // Sum effective_tokens for this month (created_at is a Unix BIGINT timestamp)
155    let now = Utc::now();
156    let start_of_month = now
157        .date_naive()
158        .with_day(1)
159        .unwrap()
160        .and_hms_opt(0, 0, 0)
161        .unwrap()
162        .and_utc()
163        .timestamp();
164
165    let row: (i64,) = sqlx::query_as(
166        r#"
167        SELECT COALESCE(SUM(effective_tokens)::bigint, 0)
168        FROM usage_events
169        WHERE tenant_id = $1 AND created_at >= $2
170        "#,
171    )
172    .bind(tenant_id)
173    .bind(start_of_month)
174    .fetch_one(pool)
175    .await
176    .map_err(|e| AppError::Database(format!("Failed to check quota: {}", e)))?;
177
178    let used = row.0;
179    let within_quota = used < max_tokens;
180
181    if !within_quota {
182        tracing::warn!(
183            tenant_id = tenant_id,
184            tier = tier,
185            used = used,
186            max = max_tokens,
187            "Tenant exceeded MCP usage quota"
188        );
189    }
190
191    Ok(within_quota)
192}