1use super::graphql::OperationProcessor;
2use graphql_parser::schema::Document;
3use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
4use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
5use serde::Serialize;
6use std::{
7 collections::{HashMap, VecDeque},
8 sync::{Arc, Mutex},
9 time::Duration,
10};
11use thiserror::Error;
12use tokio_util::sync::CancellationToken;
13
/// Usage payload serialized to JSON and POSTed by `UsageAgent::send_report`.
///
/// Field names are used verbatim as JSON keys (no serde renaming).
#[derive(Serialize, Debug)]
pub struct Report {
    /// Number of operations recorded in this report; `produce_report` increments it once
    /// per entry pushed to `operations`.
    size: usize,
    /// Operation descriptions keyed by operation hash; each `Operation::operationMapKey`
    /// refers to a key in this map.
    map: HashMap<String, OperationMapRecord>,
    /// One entry per successfully processed execution.
    operations: Vec<Operation>,
}
20
/// Description of one distinct operation, stored in `Report::map` under its hash.
// Field names are serialized verbatim as JSON keys, so they are spelled in camelCase.
#[allow(non_snake_case)]
#[derive(Serialize, Debug)]
struct OperationMapRecord {
    /// Operation text as returned by the operation processor.
    operation: String,
    /// Operation name; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    operationName: Option<String>,
    /// The processor's `coordinates` for this operation.
    fields: Vec<String>,
}
29
/// A single recorded execution of an operation.
// Field names are serialized verbatim as JSON keys, so they are spelled in camelCase.
#[allow(non_snake_case)]
#[derive(Serialize, Debug)]
struct Operation {
    /// Operation hash; key into `Report::map`.
    operationMapKey: String,
    /// Copied unchanged from `ExecutionReport::timestamp`.
    /// NOTE(review): unit is not visible here — presumably epoch milliseconds, confirm.
    timestamp: u64,
    /// Outcome and timing of this execution.
    execution: Execution,
    /// Client information; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<Metadata>,
    /// Persisted document hash from the execution report; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    persistedDocumentHash: Option<String>,
}
41
/// Outcome and timing of a single execution.
// Field names are serialized verbatim as JSON keys, so they are spelled in camelCase.
#[allow(non_snake_case)]
#[derive(Serialize, Debug)]
struct Execution {
    /// Success flag copied from `ExecutionReport::ok`.
    ok: bool,
    /// Execution time in nanoseconds (`Duration::as_nanos`).
    duration: u128,
    /// Error count copied from `ExecutionReport::errors`.
    errorsTotal: usize,
}
49
/// Optional per-operation metadata attached to an `Operation`.
#[derive(Serialize, Debug)]
struct Metadata {
    /// Client that issued the operation; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    client: Option<ClientInfo>,
}
55
/// Name and version of the client that sent an operation.
#[derive(Serialize, Debug)]
struct ClientInfo {
    /// Client name; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Client version; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    version: Option<String>,
}
63
/// Raw record of one GraphQL execution, queued with `UsageAgent::add_report` and converted
/// into the serialized report format when the buffer is flushed.
#[derive(Debug, Clone)]
pub struct ExecutionReport {
    /// Schema that `operation_body` is processed against.
    pub schema: Arc<Document<'static, String>>,
    /// Client name, forwarded as report metadata.
    pub client_name: Option<String>,
    /// Client version, forwarded as report metadata.
    pub client_version: Option<String>,
    /// Copied unchanged into the report.
    /// NOTE(review): unit not visible here — presumably epoch milliseconds, confirm.
    pub timestamp: u64,
    /// Execution time; serialized as nanoseconds.
    pub duration: Duration,
    /// Whether the execution succeeded.
    pub ok: bool,
    /// Number of errors produced by the execution.
    pub errors: usize,
    /// GraphQL operation source text.
    pub operation_body: String,
    /// Operation name, if any; used in logs ("anonymous" when absent) and in the report map.
    pub operation_name: Option<String>,
    /// Persisted document hash, forwarded unchanged into the report.
    pub persisted_document_hash: Option<String>,
}
77
78#[derive(Debug, Default)]
79pub struct Buffer(Mutex<VecDeque<ExecutionReport>>);
80
81impl Buffer {
82 fn new() -> Self {
83 Self(Mutex::new(VecDeque::new()))
84 }
85
86 fn lock_buffer(
87 &self,
88 ) -> Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> {
89 let buffer: Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> =
90 self.0.lock().map_err(|e| AgentError::Lock(e.to_string()));
91 buffer
92 }
93
94 pub fn push(&self, report: ExecutionReport) -> Result<usize, AgentError> {
95 let mut buffer = self.lock_buffer()?;
96 buffer.push_back(report);
97 Ok(buffer.len())
98 }
99
100 pub fn drain(&self) -> Result<Vec<ExecutionReport>, AgentError> {
101 let mut buffer = self.lock_buffer()?;
102 let reports: Vec<ExecutionReport> = buffer.drain(..).collect();
103 Ok(reports)
104 }
105}
106
/// Buffers execution reports and periodically (or when full) sends them to the usage endpoint.
///
/// Clones share the same `buffer` and `processor` through `Arc`.
#[derive(Clone)]
pub struct UsageAgent {
    /// Access token, sent as `Authorization: Bearer <token>`.
    token: String,
    /// Queue length at or above which `flush_if_full` spawns a flush.
    buffer_size: usize,
    /// URL reports are POSTed to (may have `/{target_id}` appended by `new`).
    endpoint: String,
    /// Pending execution reports.
    buffer: Arc<Buffer>,
    /// Turns operation bodies into hashes, normalized text and coordinates.
    processor: Arc<OperationProcessor>,
    /// HTTP client wrapped with retry middleware.
    client: ClientWithMiddleware,
    /// Period of the loop in `start_flush_interval`.
    flush_interval: Duration,
    /// Value of the `User-Agent` header on report requests.
    user_agent: String,
}
118
/// Returns `value` only when it holds a non-empty string.
///
/// Both `None` and `Some("")` become `None`, so blank client names, client versions and
/// operation names are omitted from the serialized report instead of being sent as `""`.
fn non_empty_string(value: Option<String>) -> Option<String> {
    // Keep the value only if it has content; the previous predicate was inverted and
    // kept *only* empty strings.
    value.filter(|s| !s.is_empty())
}
122
/// Errors produced while buffering or sending usage reports.
#[derive(Error, Debug)]
pub enum AgentError {
    /// The report buffer's mutex could not be locked (std `Mutex::lock` fails only when
    /// poisoned); carries the lock error's message.
    #[error("unable to acquire lock: {0}")]
    Lock(String),
    /// The usage endpoint responded `401 Unauthorized`.
    #[error("unable to send report: token is missing")]
    Unauthorized,
    /// The usage endpoint responded `403 Forbidden`.
    #[error("unable to send report: no access")]
    Forbidden,
    /// The usage endpoint responded `429 Too Many Requests`.
    #[error("unable to send report: rate limited")]
    RateLimited,
    /// Any other failure: JSON serialization, transport error, or an unexpected HTTP
    /// status (formatted as `(<status>) <body>`).
    #[error("unable to send report: {0}")]
    Unknown(String),
}
136
137impl UsageAgent {
138 #[allow(clippy::too_many_arguments)]
139 pub fn new(
140 token: String,
141 endpoint: String,
142 target_id: Option<String>,
143 buffer_size: usize,
144 connect_timeout: Duration,
145 request_timeout: Duration,
146 accept_invalid_certs: bool,
147 flush_interval: Duration,
148 user_agent: String,
149 ) -> Self {
150 let processor = Arc::new(OperationProcessor::new());
151
152 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
153
154 let reqwest_agent = reqwest::Client::builder()
155 .danger_accept_invalid_certs(accept_invalid_certs)
156 .connect_timeout(connect_timeout)
157 .timeout(request_timeout)
158 .build()
159 .map_err(|err| err.to_string())
160 .expect("Couldn't instantiate the http client for reports sending!");
161 let client = ClientBuilder::new(reqwest_agent)
162 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
163 .build();
164 let buffer = Arc::new(Buffer::new());
165
166 let mut endpoint = endpoint;
167
168 if token.starts_with("hvo1/") {
169 if let Some(target_id) = target_id {
170 endpoint.push_str(&format!("/{}", target_id));
171 }
172 }
173
174 UsageAgent {
175 buffer,
176 processor,
177 endpoint,
178 token,
179 buffer_size,
180 client,
181 flush_interval,
182 user_agent,
183 }
184 }
185
186 fn produce_report(&self, reports: Vec<ExecutionReport>) -> Result<Report, AgentError> {
187 let mut report = Report {
188 size: 0,
189 map: HashMap::new(),
190 operations: Vec::new(),
191 };
192
193 for op in reports {
195 let operation = self.processor.process(&op.operation_body, &op.schema);
196 match operation {
197 Err(e) => {
198 tracing::warn!(
199 "Dropping operation \"{}\" (phase: PROCESSING): {}",
200 op.operation_name
201 .clone()
202 .or_else(|| Some("anonymous".to_string()))
203 .unwrap(),
204 e
205 );
206 continue;
207 }
208 Ok(operation) => match operation {
209 Some(operation) => {
210 let hash = operation.hash;
211
212 let client_name = non_empty_string(op.client_name);
213 let client_version = non_empty_string(op.client_version);
214
215 let metadata: Option<Metadata> =
216 if client_name.is_some() || client_version.is_some() {
217 Some(Metadata {
218 client: Some(ClientInfo {
219 name: client_name,
220 version: client_version,
221 }),
222 })
223 } else {
224 None
225 };
226 report.operations.push(Operation {
227 operationMapKey: hash.clone(),
228 timestamp: op.timestamp,
229 execution: Execution {
230 ok: op.ok,
231 duration: op.duration.as_nanos(),
232 errorsTotal: op.errors,
233 },
234 persistedDocumentHash: op.persisted_document_hash,
235 metadata,
236 });
237 if let std::collections::hash_map::Entry::Vacant(e) = report.map.entry(hash)
238 {
239 e.insert(OperationMapRecord {
240 operation: operation.operation,
241 operationName: non_empty_string(op.operation_name),
242 fields: operation.coordinates,
243 });
244 }
245 report.size += 1;
246 }
247 None => {
248 tracing::debug!(
249 "Dropping operation (phase: PROCESSING): probably introspection query"
250 );
251 }
252 },
253 }
254 }
255
256 Ok(report)
257 }
258
259 pub fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> {
260 let size = self.buffer.push(execution_report)?;
261
262 self.flush_if_full(size)?;
263
264 Ok(())
265 }
266
267 pub async fn send_report(&self, report: Report) -> Result<(), AgentError> {
268 let report_body =
269 serde_json::to_vec(&report).map_err(|e| AgentError::Unknown(e.to_string()))?;
270 let resp = self
272 .client
273 .post(&self.endpoint)
274 .header("X-Usage-API-Version", "2")
275 .header(
276 reqwest::header::AUTHORIZATION,
277 format!("Bearer {}", self.token),
278 )
279 .header(reqwest::header::USER_AGENT, self.user_agent.to_string())
280 .header(reqwest::header::CONTENT_TYPE, "application/json")
281 .header(reqwest::header::CONTENT_LENGTH, report_body.len())
282 .body(report_body)
283 .send()
284 .await
285 .map_err(|e| AgentError::Unknown(e.to_string()))?;
286
287 match resp.status() {
288 reqwest::StatusCode::OK => Ok(()),
289 reqwest::StatusCode::UNAUTHORIZED => Err(AgentError::Unauthorized),
290 reqwest::StatusCode::FORBIDDEN => Err(AgentError::Forbidden),
291 reqwest::StatusCode::TOO_MANY_REQUESTS => Err(AgentError::RateLimited),
292 _ => Err(AgentError::Unknown(format!(
293 "({}) {}",
294 resp.status(),
295 resp.text().await.unwrap_or_default()
296 ))),
297 }
298 }
299
300 pub fn flush_if_full(&self, size: usize) -> Result<(), AgentError> {
301 if size >= self.buffer_size {
302 let cloned_self = self.clone();
303 tokio::task::spawn(async move {
304 cloned_self.flush().await;
305 });
306 }
307
308 Ok(())
309 }
310
311 pub async fn flush(&self) {
312 let execution_reports = match self.buffer.drain() {
313 Ok(res) => res,
314 Err(e) => {
315 tracing::error!("Unable to acquire lock for State in drain_reports: {}", e);
316 Vec::new()
317 }
318 };
319 let size = execution_reports.len();
320
321 if size > 0 {
322 match self.produce_report(execution_reports) {
323 Ok(report) => match self.send_report(report).await {
324 Ok(_) => tracing::debug!("Reported {} operations", size),
325 Err(e) => tracing::error!("{}", e),
326 },
327 Err(e) => tracing::error!("{}", e),
328 }
329 }
330 }
331 pub async fn start_flush_interval(&self, token: Option<CancellationToken>) {
332 let mut tokio_interval = tokio::time::interval(self.flush_interval);
333
334 match token {
335 Some(token) => loop {
336 tokio::select! {
337 _ = tokio_interval.tick() => { self.flush().await; }
338 _ = token.cancelled() => { println!("Shutting down."); return; }
339 }
340 },
341 None => loop {
342 tokio_interval.tick().await;
343 self.flush().await;
344 },
345 }
346 }
347}