hive_console_sdk/
agent.rs

1use super::graphql::OperationProcessor;
2use graphql_parser::schema::Document;
3use reqwest::header::{HeaderMap, HeaderValue};
4use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
5use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
6use serde::Serialize;
7use std::{
8    collections::{HashMap, VecDeque},
9    sync::{Arc, Mutex},
10    time::Duration,
11};
12use thiserror::Error;
13use tokio_util::sync::CancellationToken;
14
15#[derive(Serialize, Debug)]
16pub struct Report {
17    size: usize,
18    map: HashMap<String, OperationMapRecord>,
19    operations: Vec<Operation>,
20}
21
22#[allow(non_snake_case)]
23#[derive(Serialize, Debug)]
24struct OperationMapRecord {
25    operation: String,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    operationName: Option<String>,
28    fields: Vec<String>,
29}
30
31#[allow(non_snake_case)]
32#[derive(Serialize, Debug)]
33struct Operation {
34    operationMapKey: String,
35    timestamp: u64,
36    execution: Execution,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    metadata: Option<Metadata>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    persistedDocumentHash: Option<String>,
41}
42
43#[allow(non_snake_case)]
44#[derive(Serialize, Debug)]
45struct Execution {
46    ok: bool,
47    duration: u128,
48    errorsTotal: usize,
49}
50
51#[derive(Serialize, Debug)]
52struct Metadata {
53    #[serde(skip_serializing_if = "Option::is_none")]
54    client: Option<ClientInfo>,
55}
56
57#[derive(Serialize, Debug)]
58struct ClientInfo {
59    #[serde(skip_serializing_if = "Option::is_none")]
60    name: Option<String>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    version: Option<String>,
63}
64
65#[derive(Debug, Clone)]
66pub struct ExecutionReport {
67    pub schema: Arc<Document<'static, String>>,
68    pub client_name: Option<String>,
69    pub client_version: Option<String>,
70    pub timestamp: u64,
71    pub duration: Duration,
72    pub ok: bool,
73    pub errors: usize,
74    pub operation_body: String,
75    pub operation_name: Option<String>,
76    pub persisted_document_hash: Option<String>,
77}
78
79#[derive(Debug, Default)]
80pub struct Buffer(Mutex<VecDeque<ExecutionReport>>);
81
82impl Buffer {
83    fn new() -> Self {
84        Self(Mutex::new(VecDeque::new()))
85    }
86
87    fn lock_buffer(
88        &self,
89    ) -> Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> {
90        let buffer: Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> =
91            self.0.lock().map_err(|e| AgentError::Lock(e.to_string()));
92        buffer
93    }
94
95    pub fn push(&self, report: ExecutionReport) -> Result<usize, AgentError> {
96        let mut buffer = self.lock_buffer()?;
97        buffer.push_back(report);
98        Ok(buffer.len())
99    }
100
101    pub fn drain(&self) -> Result<Vec<ExecutionReport>, AgentError> {
102        let mut buffer = self.lock_buffer()?;
103        let reports: Vec<ExecutionReport> = buffer.drain(..).collect();
104        Ok(reports)
105    }
106}
107pub struct UsageAgent {
108    buffer_size: usize,
109    endpoint: String,
110    buffer: Buffer,
111    processor: OperationProcessor,
112    client: ClientWithMiddleware,
113    flush_interval: Duration,
114}
115
/// Normalizes an optional string by treating an empty string as absent.
///
/// Returns `Some(value)` only when `value` holds at least one character;
/// `None` and `Some("")` both yield `None`.
fn non_empty_string(value: Option<String>) -> Option<String> {
    // Keep the value only when it is NOT empty (the previous predicate was
    // inverted and discarded every non-empty string).
    value.filter(|s| !s.is_empty())
}
119
120#[derive(Error, Debug)]
121pub enum AgentError {
122    #[error("unable to acquire lock: {0}")]
123    Lock(String),
124    #[error("unable to send report: token is missing")]
125    Unauthorized,
126    #[error("unable to send report: no access")]
127    Forbidden,
128    #[error("unable to send report: rate limited")]
129    RateLimited,
130    #[error("invalid token provided: {0}")]
131    InvalidToken(String),
132    #[error("unable to instantiate the http client for reports sending: {0}")]
133    HTTPClientCreationError(reqwest::Error),
134    #[error("unable to send report: {0}")]
135    Unknown(String),
136}
137
138impl UsageAgent {
139    #[allow(clippy::too_many_arguments)]
140    pub fn try_new(
141        token: &str,
142        endpoint: String,
143        target_id: Option<String>,
144        buffer_size: usize,
145        connect_timeout: Duration,
146        request_timeout: Duration,
147        accept_invalid_certs: bool,
148        flush_interval: Duration,
149        user_agent: String,
150    ) -> Result<Arc<Self>, AgentError> {
151        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
152
153        let mut default_headers = HeaderMap::new();
154
155        default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2"));
156
157        let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token))
158            .map_err(|_| AgentError::InvalidToken(token.to_string()))?;
159
160        authorization_header.set_sensitive(true);
161
162        default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header);
163
164        default_headers.insert(
165            reqwest::header::CONTENT_TYPE,
166            HeaderValue::from_static("application/json"),
167        );
168
169        let reqwest_agent = reqwest::Client::builder()
170            .danger_accept_invalid_certs(accept_invalid_certs)
171            .connect_timeout(connect_timeout)
172            .timeout(request_timeout)
173            .user_agent(user_agent)
174            .build()
175            .map_err(AgentError::HTTPClientCreationError)?;
176        let client = ClientBuilder::new(reqwest_agent)
177            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
178            .build();
179
180        let mut endpoint = endpoint;
181
182        if token.starts_with("hvo1/") || token.starts_with("hvu1/") || token.starts_with("hvp1/") {
183            if let Some(target_id) = target_id {
184                endpoint.push_str(&format!("/{}", target_id));
185            }
186        }
187
188        Ok(Arc::new(Self {
189            buffer_size,
190            endpoint,
191            buffer: Buffer::new(),
192            processor: OperationProcessor::new(),
193            client,
194            flush_interval,
195        }))
196    }
197
198    fn produce_report(&self, reports: Vec<ExecutionReport>) -> Result<Report, AgentError> {
199        let mut report = Report {
200            size: 0,
201            map: HashMap::new(),
202            operations: Vec::new(),
203        };
204
205        // iterate over reports and check if they are valid
206        for op in reports {
207            let operation = self.processor.process(&op.operation_body, &op.schema);
208            match operation {
209                Err(e) => {
210                    tracing::warn!(
211                        "Dropping operation \"{}\" (phase: PROCESSING): {}",
212                        op.operation_name
213                            .clone()
214                            .or_else(|| Some("anonymous".to_string()))
215                            .unwrap(),
216                        e
217                    );
218                    continue;
219                }
220                Ok(operation) => match operation {
221                    Some(operation) => {
222                        let hash = operation.hash;
223
224                        let client_name = non_empty_string(op.client_name);
225                        let client_version = non_empty_string(op.client_version);
226
227                        let metadata: Option<Metadata> =
228                            if client_name.is_some() || client_version.is_some() {
229                                Some(Metadata {
230                                    client: Some(ClientInfo {
231                                        name: client_name,
232                                        version: client_version,
233                                    }),
234                                })
235                            } else {
236                                None
237                            };
238                        report.operations.push(Operation {
239                            operationMapKey: hash.clone(),
240                            timestamp: op.timestamp,
241                            execution: Execution {
242                                ok: op.ok,
243                                duration: op.duration.as_nanos(),
244                                errorsTotal: op.errors,
245                            },
246                            persistedDocumentHash: op.persisted_document_hash,
247                            metadata,
248                        });
249                        if let std::collections::hash_map::Entry::Vacant(e) = report.map.entry(hash)
250                        {
251                            e.insert(OperationMapRecord {
252                                operation: operation.operation,
253                                operationName: non_empty_string(op.operation_name),
254                                fields: operation.coordinates,
255                            });
256                        }
257                        report.size += 1;
258                    }
259                    None => {
260                        tracing::debug!(
261                            "Dropping operation (phase: PROCESSING): probably introspection query"
262                        );
263                    }
264                },
265            }
266        }
267
268        Ok(report)
269    }
270
271    pub async fn send_report(&self, report: Report) -> Result<(), AgentError> {
272        let report_body =
273            serde_json::to_vec(&report).map_err(|e| AgentError::Unknown(e.to_string()))?;
274        // Based on https://the-guild.dev/graphql/hive/docs/specs/usage-reports#data-structure
275        let resp = self
276            .client
277            .post(&self.endpoint)
278            .header(reqwest::header::CONTENT_LENGTH, report_body.len())
279            .body(report_body)
280            .send()
281            .await
282            .map_err(|e| AgentError::Unknown(e.to_string()))?;
283
284        match resp.status() {
285            reqwest::StatusCode::OK => Ok(()),
286            reqwest::StatusCode::UNAUTHORIZED => Err(AgentError::Unauthorized),
287            reqwest::StatusCode::FORBIDDEN => Err(AgentError::Forbidden),
288            reqwest::StatusCode::TOO_MANY_REQUESTS => Err(AgentError::RateLimited),
289            _ => Err(AgentError::Unknown(format!(
290                "({}) {}",
291                resp.status(),
292                resp.text().await.unwrap_or_default()
293            ))),
294        }
295    }
296
297    pub async fn flush(&self) {
298        let execution_reports = match self.buffer.drain() {
299            Ok(res) => res,
300            Err(e) => {
301                tracing::error!("Unable to acquire lock for State in drain_reports: {}", e);
302                Vec::new()
303            }
304        };
305        let size = execution_reports.len();
306
307        if size > 0 {
308            match self.produce_report(execution_reports) {
309                Ok(report) => match self.send_report(report).await {
310                    Ok(_) => tracing::debug!("Reported {} operations", size),
311                    Err(e) => tracing::error!("{}", e),
312                },
313                Err(e) => tracing::error!("{}", e),
314            }
315        }
316    }
317    pub async fn start_flush_interval(&self, token: Option<CancellationToken>) {
318        let mut tokio_interval = tokio::time::interval(self.flush_interval);
319
320        match token {
321            Some(token) => loop {
322                tokio::select! {
323                    _ = tokio_interval.tick() => { self.flush().await; }
324                    _ = token.cancelled() => { println!("Shutting down."); return; }
325                }
326            },
327            None => loop {
328                tokio_interval.tick().await;
329                self.flush().await;
330            },
331        }
332    }
333}
334
335pub trait UsageAgentExt {
336    fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError>;
337    fn flush_if_full(&self, size: usize) -> Result<(), AgentError>;
338}
339
340impl UsageAgentExt for Arc<UsageAgent> {
341    fn flush_if_full(&self, size: usize) -> Result<(), AgentError> {
342        if size >= self.buffer_size {
343            let cloned_self = self.clone();
344            tokio::task::spawn(async move {
345                cloned_self.flush().await;
346            });
347        }
348
349        Ok(())
350    }
351
352    fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> {
353        let size = self.buffer.push(execution_report)?;
354
355        self.flush_if_full(size)?;
356
357        Ok(())
358    }
359}