1use super::graphql::OperationProcessor;
2use graphql_parser::schema::Document;
3use reqwest::header::{HeaderMap, HeaderValue};
4use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
5use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
6use serde::Serialize;
7use std::{
8 collections::{HashMap, VecDeque},
9 sync::{Arc, Mutex},
10 time::Duration,
11};
12use thiserror::Error;
13use tokio_util::sync::CancellationToken;
14
15#[derive(Serialize, Debug)]
16pub struct Report {
17 size: usize,
18 map: HashMap<String, OperationMapRecord>,
19 operations: Vec<Operation>,
20}
21
22#[allow(non_snake_case)]
23#[derive(Serialize, Debug)]
24struct OperationMapRecord {
25 operation: String,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 operationName: Option<String>,
28 fields: Vec<String>,
29}
30
31#[allow(non_snake_case)]
32#[derive(Serialize, Debug)]
33struct Operation {
34 operationMapKey: String,
35 timestamp: u64,
36 execution: Execution,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 metadata: Option<Metadata>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 persistedDocumentHash: Option<String>,
41}
42
43#[allow(non_snake_case)]
44#[derive(Serialize, Debug)]
45struct Execution {
46 ok: bool,
47 duration: u128,
48 errorsTotal: usize,
49}
50
51#[derive(Serialize, Debug)]
52struct Metadata {
53 #[serde(skip_serializing_if = "Option::is_none")]
54 client: Option<ClientInfo>,
55}
56
57#[derive(Serialize, Debug)]
58struct ClientInfo {
59 #[serde(skip_serializing_if = "Option::is_none")]
60 name: Option<String>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 version: Option<String>,
63}
64
65#[derive(Debug, Clone)]
66pub struct ExecutionReport {
67 pub schema: Arc<Document<'static, String>>,
68 pub client_name: Option<String>,
69 pub client_version: Option<String>,
70 pub timestamp: u64,
71 pub duration: Duration,
72 pub ok: bool,
73 pub errors: usize,
74 pub operation_body: String,
75 pub operation_name: Option<String>,
76 pub persisted_document_hash: Option<String>,
77}
78
79#[derive(Debug, Default)]
80pub struct Buffer(Mutex<VecDeque<ExecutionReport>>);
81
82impl Buffer {
83 fn new() -> Self {
84 Self(Mutex::new(VecDeque::new()))
85 }
86
87 fn lock_buffer(
88 &self,
89 ) -> Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> {
90 let buffer: Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> =
91 self.0.lock().map_err(|e| AgentError::Lock(e.to_string()));
92 buffer
93 }
94
95 pub fn push(&self, report: ExecutionReport) -> Result<usize, AgentError> {
96 let mut buffer = self.lock_buffer()?;
97 buffer.push_back(report);
98 Ok(buffer.len())
99 }
100
101 pub fn drain(&self) -> Result<Vec<ExecutionReport>, AgentError> {
102 let mut buffer = self.lock_buffer()?;
103 let reports: Vec<ExecutionReport> = buffer.drain(..).collect();
104 Ok(reports)
105 }
106}
107pub struct UsageAgent {
108 buffer_size: usize,
109 endpoint: String,
110 buffer: Buffer,
111 processor: OperationProcessor,
112 client: ClientWithMiddleware,
113 flush_interval: Duration,
114}
115
/// Normalizes an optional string so that "absent" and "empty" look the same.
///
/// Returns `None` when `value` is `None` or holds an empty string; otherwise
/// returns the value unchanged.
fn non_empty_string(value: Option<String>) -> Option<String> {
    // The predicate must be negated: `filter` keeps values for which the closure
    // returns `true`, and we want to keep the non-empty ones.
    value.filter(|s| !s.is_empty())
}
119
120#[derive(Error, Debug)]
121pub enum AgentError {
122 #[error("unable to acquire lock: {0}")]
123 Lock(String),
124 #[error("unable to send report: token is missing")]
125 Unauthorized,
126 #[error("unable to send report: no access")]
127 Forbidden,
128 #[error("unable to send report: rate limited")]
129 RateLimited,
130 #[error("invalid token provided: {0}")]
131 InvalidToken(String),
132 #[error("unable to instantiate the http client for reports sending: {0}")]
133 HTTPClientCreationError(reqwest::Error),
134 #[error("unable to send report: {0}")]
135 Unknown(String),
136}
137
138impl UsageAgent {
139 #[allow(clippy::too_many_arguments)]
140 pub fn try_new(
141 token: &str,
142 endpoint: String,
143 target_id: Option<String>,
144 buffer_size: usize,
145 connect_timeout: Duration,
146 request_timeout: Duration,
147 accept_invalid_certs: bool,
148 flush_interval: Duration,
149 user_agent: String,
150 ) -> Result<Arc<Self>, AgentError> {
151 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
152
153 let mut default_headers = HeaderMap::new();
154
155 default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2"));
156
157 let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token))
158 .map_err(|_| AgentError::InvalidToken(token.to_string()))?;
159
160 authorization_header.set_sensitive(true);
161
162 default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header);
163
164 default_headers.insert(
165 reqwest::header::CONTENT_TYPE,
166 HeaderValue::from_static("application/json"),
167 );
168
169 let reqwest_agent = reqwest::Client::builder()
170 .danger_accept_invalid_certs(accept_invalid_certs)
171 .connect_timeout(connect_timeout)
172 .timeout(request_timeout)
173 .user_agent(user_agent)
174 .build()
175 .map_err(AgentError::HTTPClientCreationError)?;
176 let client = ClientBuilder::new(reqwest_agent)
177 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
178 .build();
179
180 let mut endpoint = endpoint;
181
182 if token.starts_with("hvo1/") || token.starts_with("hvu1/") || token.starts_with("hvp1/") {
183 if let Some(target_id) = target_id {
184 endpoint.push_str(&format!("/{}", target_id));
185 }
186 }
187
188 Ok(Arc::new(Self {
189 buffer_size,
190 endpoint,
191 buffer: Buffer::new(),
192 processor: OperationProcessor::new(),
193 client,
194 flush_interval,
195 }))
196 }
197
198 fn produce_report(&self, reports: Vec<ExecutionReport>) -> Result<Report, AgentError> {
199 let mut report = Report {
200 size: 0,
201 map: HashMap::new(),
202 operations: Vec::new(),
203 };
204
205 for op in reports {
207 let operation = self.processor.process(&op.operation_body, &op.schema);
208 match operation {
209 Err(e) => {
210 tracing::warn!(
211 "Dropping operation \"{}\" (phase: PROCESSING): {}",
212 op.operation_name
213 .clone()
214 .or_else(|| Some("anonymous".to_string()))
215 .unwrap(),
216 e
217 );
218 continue;
219 }
220 Ok(operation) => match operation {
221 Some(operation) => {
222 let hash = operation.hash;
223
224 let client_name = non_empty_string(op.client_name);
225 let client_version = non_empty_string(op.client_version);
226
227 let metadata: Option<Metadata> =
228 if client_name.is_some() || client_version.is_some() {
229 Some(Metadata {
230 client: Some(ClientInfo {
231 name: client_name,
232 version: client_version,
233 }),
234 })
235 } else {
236 None
237 };
238 report.operations.push(Operation {
239 operationMapKey: hash.clone(),
240 timestamp: op.timestamp,
241 execution: Execution {
242 ok: op.ok,
243 duration: op.duration.as_nanos(),
244 errorsTotal: op.errors,
245 },
246 persistedDocumentHash: op.persisted_document_hash,
247 metadata,
248 });
249 if let std::collections::hash_map::Entry::Vacant(e) = report.map.entry(hash)
250 {
251 e.insert(OperationMapRecord {
252 operation: operation.operation,
253 operationName: non_empty_string(op.operation_name),
254 fields: operation.coordinates,
255 });
256 }
257 report.size += 1;
258 }
259 None => {
260 tracing::debug!(
261 "Dropping operation (phase: PROCESSING): probably introspection query"
262 );
263 }
264 },
265 }
266 }
267
268 Ok(report)
269 }
270
271 pub async fn send_report(&self, report: Report) -> Result<(), AgentError> {
272 let report_body =
273 serde_json::to_vec(&report).map_err(|e| AgentError::Unknown(e.to_string()))?;
274 let resp = self
276 .client
277 .post(&self.endpoint)
278 .header(reqwest::header::CONTENT_LENGTH, report_body.len())
279 .body(report_body)
280 .send()
281 .await
282 .map_err(|e| AgentError::Unknown(e.to_string()))?;
283
284 match resp.status() {
285 reqwest::StatusCode::OK => Ok(()),
286 reqwest::StatusCode::UNAUTHORIZED => Err(AgentError::Unauthorized),
287 reqwest::StatusCode::FORBIDDEN => Err(AgentError::Forbidden),
288 reqwest::StatusCode::TOO_MANY_REQUESTS => Err(AgentError::RateLimited),
289 _ => Err(AgentError::Unknown(format!(
290 "({}) {}",
291 resp.status(),
292 resp.text().await.unwrap_or_default()
293 ))),
294 }
295 }
296
297 pub async fn flush(&self) {
298 let execution_reports = match self.buffer.drain() {
299 Ok(res) => res,
300 Err(e) => {
301 tracing::error!("Unable to acquire lock for State in drain_reports: {}", e);
302 Vec::new()
303 }
304 };
305 let size = execution_reports.len();
306
307 if size > 0 {
308 match self.produce_report(execution_reports) {
309 Ok(report) => match self.send_report(report).await {
310 Ok(_) => tracing::debug!("Reported {} operations", size),
311 Err(e) => tracing::error!("{}", e),
312 },
313 Err(e) => tracing::error!("{}", e),
314 }
315 }
316 }
317 pub async fn start_flush_interval(&self, token: Option<CancellationToken>) {
318 let mut tokio_interval = tokio::time::interval(self.flush_interval);
319
320 match token {
321 Some(token) => loop {
322 tokio::select! {
323 _ = tokio_interval.tick() => { self.flush().await; }
324 _ = token.cancelled() => { println!("Shutting down."); return; }
325 }
326 },
327 None => loop {
328 tokio_interval.tick().await;
329 self.flush().await;
330 },
331 }
332 }
333}
334
335pub trait UsageAgentExt {
336 fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError>;
337 fn flush_if_full(&self, size: usize) -> Result<(), AgentError>;
338}
339
340impl UsageAgentExt for Arc<UsageAgent> {
341 fn flush_if_full(&self, size: usize) -> Result<(), AgentError> {
342 if size >= self.buffer_size {
343 let cloned_self = self.clone();
344 tokio::task::spawn(async move {
345 cloned_self.flush().await;
346 });
347 }
348
349 Ok(())
350 }
351
352 fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> {
353 let size = self.buffer.push(execution_report)?;
354
355 self.flush_if_full(size)?;
356
357 Ok(())
358 }
359}