hive_console_sdk/
agent.rs

1use super::graphql::OperationProcessor;
2use graphql_parser::schema::Document;
3use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
4use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
5use serde::Serialize;
6use std::{
7    collections::{HashMap, VecDeque},
8    sync::{Arc, Mutex},
9    time::Duration,
10};
11use thiserror::Error;
12use tokio_util::sync::CancellationToken;
13
/// Batch of processed executions sent to the usage endpoint as one JSON request body.
///
/// serde serializes the fields in declaration order using their names verbatim, so the
/// layout below is the wire format.
#[derive(Serialize, Debug)]
pub struct Report {
    /// Number of entries pushed into `operations` (incremented once per accepted execution
    /// by `UsageAgent::produce_report`).
    size: usize,
    /// Operation records keyed by the processed operation's hash; each hash is inserted
    /// only once, so repeated executions of the same operation share a record.
    map: HashMap<String, OperationMapRecord>,
    /// One entry per accepted execution; each entry's `operationMapKey` is a key of `map`.
    operations: Vec<Operation>,
}
20
/// Description of one distinct operation, stored in `Report::map` under its hash.
// Field names are camelCase because serde emits them verbatim as JSON keys; the lint
// allowance exists only for that reason.
#[allow(non_snake_case)]
#[derive(Serialize, Debug)]
struct OperationMapRecord {
    /// Operation text as returned by the operation processor (the `operation` field of its
    /// result).
    operation: String,
    /// Operation name; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    operationName: Option<String>,
    /// Coordinates reported by the operation processor for this operation.
    fields: Vec<String>,
}
29
/// One execution entry in `Report::operations`.
// Field names are camelCase because serde emits them verbatim as JSON keys; the lint
// allowance exists only for that reason.
#[allow(non_snake_case)]
#[derive(Serialize, Debug)]
struct Operation {
    /// Key into `Report::map` (the processed operation's hash).
    operationMapKey: String,
    /// Forwarded unchanged from `ExecutionReport::timestamp`.
    timestamp: u64,
    /// Outcome and timing of this execution.
    execution: Execution,
    /// Client metadata; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<Metadata>,
    /// Persisted-document hash, if any; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    persistedDocumentHash: Option<String>,
}
41
/// Outcome of a single execution as sent in the usage report.
// Field names are camelCase because serde emits them verbatim as JSON keys; the lint
// allowance exists only for that reason.
#[allow(non_snake_case)]
#[derive(Serialize, Debug)]
struct Execution {
    /// Forwarded from `ExecutionReport::ok`.
    ok: bool,
    /// Duration in nanoseconds (filled from `Duration::as_nanos`, hence `u128`).
    duration: u128,
    /// Forwarded from `ExecutionReport::errors`.
    errorsTotal: usize,
}
49
/// Optional per-execution metadata attached to an `Operation`.
#[derive(Serialize, Debug)]
struct Metadata {
    /// Client identity; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    client: Option<ClientInfo>,
}
55
/// Name and version of the client that issued an execution.
#[derive(Serialize, Debug)]
struct ClientInfo {
    /// Client name; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Client version; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    version: Option<String>,
}
63
/// Raw record of one GraphQL execution, queued via `UsageAgent::add_report` and converted
/// into report entries when the buffer is flushed.
#[derive(Debug, Clone)]
pub struct ExecutionReport {
    /// Schema passed to the operation processor together with `operation_body`.
    pub schema: Arc<Document<'static, String>>,
    /// Client name; passed through `non_empty_string` when the report is produced.
    pub client_name: Option<String>,
    /// Client version; passed through `non_empty_string` when the report is produced.
    pub client_version: Option<String>,
    /// Forwarded unchanged as the operation's `timestamp`.
    // NOTE(review): the unit is not visible here — presumably milliseconds since the Unix
    // epoch; confirm against the usage-report spec.
    pub timestamp: u64,
    /// Execution duration; serialized in nanoseconds.
    pub duration: Duration,
    /// Forwarded as `execution.ok`.
    pub ok: bool,
    /// Forwarded as `execution.errorsTotal`.
    pub errors: usize,
    /// Operation document text that is processed against `schema`.
    pub operation_body: String,
    /// Operation name, used in log messages and in the operation-map record.
    pub operation_name: Option<String>,
    /// Forwarded as `persistedDocumentHash`.
    pub persisted_document_hash: Option<String>,
}
77
78#[derive(Debug, Default)]
79pub struct Buffer(Mutex<VecDeque<ExecutionReport>>);
80
81impl Buffer {
82    fn new() -> Self {
83        Self(Mutex::new(VecDeque::new()))
84    }
85
86    fn lock_buffer(
87        &self,
88    ) -> Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> {
89        let buffer: Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> =
90            self.0.lock().map_err(|e| AgentError::Lock(e.to_string()));
91        buffer
92    }
93
94    pub fn push(&self, report: ExecutionReport) -> Result<usize, AgentError> {
95        let mut buffer = self.lock_buffer()?;
96        buffer.push_back(report);
97        Ok(buffer.len())
98    }
99
100    pub fn drain(&self) -> Result<Vec<ExecutionReport>, AgentError> {
101        let mut buffer = self.lock_buffer()?;
102        let reports: Vec<ExecutionReport> = buffer.drain(..).collect();
103        Ok(reports)
104    }
105}
106
107#[derive(Clone)]
108pub struct UsageAgent {
109    token: String,
110    buffer_size: usize,
111    endpoint: String,
112    buffer: Arc<Buffer>,
113    processor: Arc<OperationProcessor>,
114    client: ClientWithMiddleware,
115    flush_interval: Duration,
116    user_agent: String,
117}
118
/// Treats an empty string as absent.
///
/// Returns `Some(s)` only when `value` is `Some(s)` with a non-empty `s`; both `None` and
/// `Some("")` yield `None`.
fn non_empty_string(value: Option<String>) -> Option<String> {
    // `Option::filter` keeps the value when the predicate is `true`, so the predicate must
    // test for content (the previous `is_empty()` kept only empty strings).
    value.filter(|s| !s.is_empty())
}
122
123#[derive(Error, Debug)]
124pub enum AgentError {
125    #[error("unable to acquire lock: {0}")]
126    Lock(String),
127    #[error("unable to send report: token is missing")]
128    Unauthorized,
129    #[error("unable to send report: no access")]
130    Forbidden,
131    #[error("unable to send report: rate limited")]
132    RateLimited,
133    #[error("unable to send report: {0}")]
134    Unknown(String),
135}
136
137impl UsageAgent {
138    #[allow(clippy::too_many_arguments)]
139    pub fn new(
140        token: String,
141        endpoint: String,
142        target_id: Option<String>,
143        buffer_size: usize,
144        connect_timeout: Duration,
145        request_timeout: Duration,
146        accept_invalid_certs: bool,
147        flush_interval: Duration,
148        user_agent: String,
149    ) -> Self {
150        let processor = Arc::new(OperationProcessor::new());
151
152        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
153
154        let reqwest_agent = reqwest::Client::builder()
155            .danger_accept_invalid_certs(accept_invalid_certs)
156            .connect_timeout(connect_timeout)
157            .timeout(request_timeout)
158            .build()
159            .map_err(|err| err.to_string())
160            .expect("Couldn't instantiate the http client for reports sending!");
161        let client = ClientBuilder::new(reqwest_agent)
162            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
163            .build();
164        let buffer = Arc::new(Buffer::new());
165
166        let mut endpoint = endpoint;
167
168        if token.starts_with("hvo1/") {
169            if let Some(target_id) = target_id {
170                endpoint.push_str(&format!("/{}", target_id));
171            }
172        }
173
174        UsageAgent {
175            buffer,
176            processor,
177            endpoint,
178            token,
179            buffer_size,
180            client,
181            flush_interval,
182            user_agent,
183        }
184    }
185
186    fn produce_report(&self, reports: Vec<ExecutionReport>) -> Result<Report, AgentError> {
187        let mut report = Report {
188            size: 0,
189            map: HashMap::new(),
190            operations: Vec::new(),
191        };
192
193        // iterate over reports and check if they are valid
194        for op in reports {
195            let operation = self.processor.process(&op.operation_body, &op.schema);
196            match operation {
197                Err(e) => {
198                    tracing::warn!(
199                        "Dropping operation \"{}\" (phase: PROCESSING): {}",
200                        op.operation_name
201                            .clone()
202                            .or_else(|| Some("anonymous".to_string()))
203                            .unwrap(),
204                        e
205                    );
206                    continue;
207                }
208                Ok(operation) => match operation {
209                    Some(operation) => {
210                        let hash = operation.hash;
211
212                        let client_name = non_empty_string(op.client_name);
213                        let client_version = non_empty_string(op.client_version);
214
215                        let metadata: Option<Metadata> =
216                            if client_name.is_some() || client_version.is_some() {
217                                Some(Metadata {
218                                    client: Some(ClientInfo {
219                                        name: client_name,
220                                        version: client_version,
221                                    }),
222                                })
223                            } else {
224                                None
225                            };
226                        report.operations.push(Operation {
227                            operationMapKey: hash.clone(),
228                            timestamp: op.timestamp,
229                            execution: Execution {
230                                ok: op.ok,
231                                duration: op.duration.as_nanos(),
232                                errorsTotal: op.errors,
233                            },
234                            persistedDocumentHash: op.persisted_document_hash,
235                            metadata,
236                        });
237                        if let std::collections::hash_map::Entry::Vacant(e) = report.map.entry(hash)
238                        {
239                            e.insert(OperationMapRecord {
240                                operation: operation.operation,
241                                operationName: non_empty_string(op.operation_name),
242                                fields: operation.coordinates,
243                            });
244                        }
245                        report.size += 1;
246                    }
247                    None => {
248                        tracing::debug!(
249                            "Dropping operation (phase: PROCESSING): probably introspection query"
250                        );
251                    }
252                },
253            }
254        }
255
256        Ok(report)
257    }
258
259    pub fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> {
260        let size = self.buffer.push(execution_report)?;
261
262        self.flush_if_full(size)?;
263
264        Ok(())
265    }
266
267    pub async fn send_report(&self, report: Report) -> Result<(), AgentError> {
268        let report_body =
269            serde_json::to_vec(&report).map_err(|e| AgentError::Unknown(e.to_string()))?;
270        // Based on https://the-guild.dev/graphql/hive/docs/specs/usage-reports#data-structure
271        let resp = self
272            .client
273            .post(&self.endpoint)
274            .header("X-Usage-API-Version", "2")
275            .header(
276                reqwest::header::AUTHORIZATION,
277                format!("Bearer {}", self.token),
278            )
279            .header(reqwest::header::USER_AGENT, self.user_agent.to_string())
280            .header(reqwest::header::CONTENT_TYPE, "application/json")
281            .header(reqwest::header::CONTENT_LENGTH, report_body.len())
282            .body(report_body)
283            .send()
284            .await
285            .map_err(|e| AgentError::Unknown(e.to_string()))?;
286
287        match resp.status() {
288            reqwest::StatusCode::OK => Ok(()),
289            reqwest::StatusCode::UNAUTHORIZED => Err(AgentError::Unauthorized),
290            reqwest::StatusCode::FORBIDDEN => Err(AgentError::Forbidden),
291            reqwest::StatusCode::TOO_MANY_REQUESTS => Err(AgentError::RateLimited),
292            _ => Err(AgentError::Unknown(format!(
293                "({}) {}",
294                resp.status(),
295                resp.text().await.unwrap_or_default()
296            ))),
297        }
298    }
299
300    pub fn flush_if_full(&self, size: usize) -> Result<(), AgentError> {
301        if size >= self.buffer_size {
302            let cloned_self = self.clone();
303            tokio::task::spawn(async move {
304                cloned_self.flush().await;
305            });
306        }
307
308        Ok(())
309    }
310
311    pub async fn flush(&self) {
312        let execution_reports = match self.buffer.drain() {
313            Ok(res) => res,
314            Err(e) => {
315                tracing::error!("Unable to acquire lock for State in drain_reports: {}", e);
316                Vec::new()
317            }
318        };
319        let size = execution_reports.len();
320
321        if size > 0 {
322            match self.produce_report(execution_reports) {
323                Ok(report) => match self.send_report(report).await {
324                    Ok(_) => tracing::debug!("Reported {} operations", size),
325                    Err(e) => tracing::error!("{}", e),
326                },
327                Err(e) => tracing::error!("{}", e),
328            }
329        }
330    }
331    pub async fn start_flush_interval(&self, token: Option<CancellationToken>) {
332        let mut tokio_interval = tokio::time::interval(self.flush_interval);
333
334        match token {
335            Some(token) => loop {
336                tokio::select! {
337                    _ = tokio_interval.tick() => { self.flush().await; }
338                    _ = token.cancelled() => { println!("Shutting down."); return; }
339                }
340            },
341            None => loop {
342                tokio_interval.tick().await;
343                self.flush().await;
344            },
345        }
346    }
347}