hive_apollo_router_plugin/
usage.rs1use crate::consts::PLUGIN_VERSION;
2use apollo_router::layers::ServiceBuilderExt;
3use apollo_router::plugin::Plugin;
4use apollo_router::plugin::PluginInit;
5use apollo_router::services::*;
6use apollo_router::Context;
7use core::ops::Drop;
8use futures::StreamExt;
9use hive_console_sdk::agent::usage_agent::UsageAgentExt;
10use hive_console_sdk::agent::usage_agent::{ExecutionReport, UsageAgent};
11use hive_console_sdk::graphql_tools::parser::parse_schema;
12use hive_console_sdk::graphql_tools::parser::schema::Document;
13use http::HeaderValue;
14use rand::Rng;
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use std::collections::HashSet;
18use std::env;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use std::time::{SystemTime, UNIX_EPOCH};
22use tokio_util::sync::CancellationToken;
23use tower::BoxError;
24use tower::ServiceBuilder;
25use tower::ServiceExt;
26
27use crate::persisted_documents::PERSISTED_DOCUMENT_HASH_KEY;
28
/// Key under which the per-request [`OperationContext`] is stored in the router `Context`.
pub(crate) const OPERATION_CONTEXT: &str = "hive::operation_context";
30
31#[derive(Serialize, Deserialize, Debug)]
32struct OperationContext {
33 pub(crate) client_name: Option<String>,
34 pub(crate) client_version: Option<String>,
35 pub(crate) timestamp: u64,
36 pub(crate) operation_body: String,
37 pub(crate) operation_name: Option<String>,
38 pub(crate) dropped: bool,
39}
40
/// Per-request settings resolved from the plugin [`Config`] at startup and handed
/// to `UsagePlugin::populate_context` for every incoming request.
#[derive(Clone, Debug)]
struct OperationConfig {
    /// An operation is sampled when a random `f64` drawn for it is below this value.
    sample_rate: f64,
    /// Operation names that are never reported.
    exclude: Option<Vec<String>>,
    /// Name of the request header that carries the client name.
    client_name_header: String,
    /// Name of the request header that carries the client version.
    client_version_header: String,
}
48
49pub struct UsagePlugin {
50 config: OperationConfig,
51 agent: Option<UsageAgent>,
52 schema: Arc<Document<'static, String>>,
53 cancellation_token: Arc<CancellationToken>,
54}
55
56#[derive(Clone, Debug, Deserialize, JsonSchema, Default)]
57pub struct Config {
58 enabled: Option<bool>,
60 registry_token: Option<String>,
63 registry_usage_endpoint: Option<String>,
68 target: Option<String>,
72 sample_rate: Option<f64>,
77 exclude: Option<Vec<String>>,
79 client_name_header: Option<String>,
80 client_version_header: Option<String>,
81 buffer_size: Option<usize>,
84 connect_timeout: Option<u64>,
88 request_timeout: Option<u64>,
92 accept_invalid_certs: Option<bool>,
95 flush_interval: Option<u64>,
98}
99
100impl UsagePlugin {
101 fn populate_context(config: OperationConfig, req: &supergraph::Request) {
102 let context = &req.context;
103 let http_request = &req.supergraph_request;
104 let headers = http_request.headers();
105
106 let get_header_value = |key: &str| {
107 headers
108 .get(key)
109 .cloned()
110 .unwrap_or_else(|| HeaderValue::from_static(""))
111 .to_str()
112 .ok()
113 .map(|v| v.to_string())
114 };
115
116 let client_name = get_header_value(&config.client_name_header);
117 let client_version = get_header_value(&config.client_version_header);
118
119 let operation_name = req.supergraph_request.body().operation_name.clone();
120 let operation_body = req
121 .supergraph_request
122 .body()
123 .query
124 .clone()
125 .expect("operation body should not be empty");
126
127 let excluded_operation_names: HashSet<String> = config
128 .exclude
129 .unwrap_or_default()
130 .clone()
131 .into_iter()
132 .collect();
133
134 let mut rng = rand::rng();
135 let sampled = rng.random::<f64>() < config.sample_rate;
136 let mut dropped = !sampled;
137
138 if !dropped {
139 if let Some(name) = &operation_name {
140 if excluded_operation_names.contains(name) {
141 dropped = true;
142 }
143 }
144 }
145
146 let _ = context.insert(
147 OPERATION_CONTEXT,
148 OperationContext {
149 dropped,
150 client_name,
151 client_version,
152 operation_name,
153 operation_body,
154 timestamp: SystemTime::now()
155 .duration_since(UNIX_EPOCH)
156 .unwrap()
157 .as_secs()
158 * 1000,
159 },
160 );
161 }
162}
163
164#[async_trait::async_trait]
165impl Plugin for UsagePlugin {
166 type Config = Config;
167
168 async fn new(init: PluginInit<Config>) -> Result<Self, BoxError> {
169 let user_config = init.config;
170
171 let enabled = user_config.enabled.unwrap_or(true);
172
173 if enabled {
174 tracing::info!("Starting GraphQL Hive Usage plugin");
175 }
176
177 let cancellation_token = Arc::new(CancellationToken::new());
178
179 let agent = if enabled {
180 let mut agent =
181 UsageAgent::builder().user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION));
182
183 if let Some(endpoint) = user_config.registry_usage_endpoint {
184 agent = agent.endpoint(endpoint);
185 } else if let Ok(env_endpoint) = env::var("HIVE_ENDPOINT") {
186 agent = agent.endpoint(env_endpoint);
187 }
188
189 if let Some(token) = user_config.registry_token {
190 agent = agent.token(token);
191 } else if let Ok(env_token) = env::var("HIVE_TOKEN") {
192 agent = agent.token(env_token);
193 }
194
195 if let Some(target_id) = user_config.target {
196 agent = agent.target_id(target_id);
197 } else if let Ok(env_target) = env::var("HIVE_TARGET_ID") {
198 agent = agent.target_id(env_target);
199 }
200
201 if let Some(buffer_size) = user_config.buffer_size {
202 agent = agent.buffer_size(buffer_size);
203 }
204
205 if let Some(connect_timeout) = user_config.connect_timeout {
206 agent = agent.connect_timeout(Duration::from_secs(connect_timeout));
207 }
208
209 if let Some(request_timeout) = user_config.request_timeout {
210 agent = agent.request_timeout(Duration::from_secs(request_timeout));
211 }
212
213 if let Some(accept_invalid_certs) = user_config.accept_invalid_certs {
214 agent = agent.accept_invalid_certs(accept_invalid_certs);
215 }
216
217 if let Some(flush_interval) = user_config.flush_interval {
218 agent = agent.flush_interval(Duration::from_secs(flush_interval));
219 }
220
221 let agent = agent.build().map_err(Box::new)?;
222
223 let cancellation_token_for_interval = cancellation_token.clone();
224 let agent_for_interval = agent.clone();
225 tokio::task::spawn(async move {
226 agent_for_interval
227 .start_flush_interval(&cancellation_token_for_interval)
228 .await;
229 });
230 Some(agent)
231 } else {
232 None
233 };
234
235 let schema = parse_schema(&init.supergraph_sdl)
236 .expect("Failed to parse schema")
237 .into_static();
238
239 Ok(UsagePlugin {
240 schema: Arc::new(schema),
241 config: OperationConfig {
242 sample_rate: user_config.sample_rate.unwrap_or(1.0),
243 exclude: user_config.exclude,
244 client_name_header: user_config
245 .client_name_header
246 .unwrap_or("graphql-client-name".to_string()),
247 client_version_header: user_config
248 .client_version_header
249 .unwrap_or("graphql-client-version".to_string()),
250 },
251 agent,
252 cancellation_token,
253 })
254 }
255
256 fn supergraph_service(&self, service: supergraph::BoxService) -> supergraph::BoxService {
257 let config = self.config.clone();
258 let schema = self.schema.clone();
259 match self.agent.clone() {
260 None => ServiceBuilder::new().service(service).boxed(),
261 Some(agent) => {
262 ServiceBuilder::new()
263 .map_future_with_request_data(
264 move |req: &supergraph::Request| {
265 Self::populate_context(config.clone(), req);
266 req.context.clone()
267 },
268 move |ctx: Context, fut| {
269 let agent = agent.clone();
270 let schema = schema.clone();
271 async move {
272 let start: Instant = Instant::now();
273
274 let operation_context = ctx
276 .get::<_, OperationContext>(OPERATION_CONTEXT)
277 .unwrap_or_default()
278 .unwrap();
279
280 let persisted_document_hash = ctx
283 .get::<_, String>(PERSISTED_DOCUMENT_HASH_KEY)
284 .ok()
285 .unwrap();
286
287 let result: supergraph::ServiceResult = fut.await;
288
289 if operation_context.dropped {
290 tracing::debug!(
291 "Dropping operation (phase: SAMPLING): {}",
292 operation_context
293 .operation_name
294 .clone()
295 .or_else(|| Some("anonymous".to_string()))
296 .unwrap()
297 );
298 return result;
299 }
300
301 let OperationContext {
302 client_name,
303 client_version,
304 operation_name,
305 timestamp,
306 operation_body,
307 ..
308 } = operation_context;
309
310 let duration = start.elapsed();
311
312 match result {
313 Err(e) => {
314 tokio::spawn(async move {
315 let res = agent
316 .add_report(ExecutionReport {
317 schema,
318 client_name,
319 client_version,
320 timestamp,
321 duration,
322 ok: false,
323 errors: 1,
324 operation_body,
325 operation_name,
326 persisted_document_hash,
327 })
328 .await;
329 if let Err(e) = res {
330 tracing::error!("Error adding report: {}", e);
331 }
332 });
333 Err(e)
334 }
335 Ok(router_response) => {
336 let is_failure =
337 !router_response.response.status().is_success();
338 Ok(router_response.map(move |response_stream| {
339 let res = response_stream
340 .map(move |response| {
341 let response_has_errors =
343 !response.errors.is_empty();
344 let agent = agent.clone();
345 let execution_report = ExecutionReport {
346 schema: schema.clone(),
347 client_name: client_name.clone(),
348 client_version: client_version.clone(),
349 timestamp,
350 duration,
351 ok: !is_failure && !response_has_errors,
352 errors: response.errors.len(),
353 operation_body: operation_body.clone(),
354 operation_name: operation_name.clone(),
355 persisted_document_hash:
356 persisted_document_hash.clone(),
357 };
358 tokio::spawn(async move {
359 let res = agent
360 .add_report(execution_report)
361 .await;
362 if let Err(e) = res {
363 tracing::error!(
364 "Error adding report: {}",
365 e
366 );
367 }
368 });
369
370 response
371 })
372 .boxed();
373
374 res
375 }))
376 }
377 }
378 }
379 },
380 )
381 .service(service)
382 .boxed()
383 }
384 }
385 }
386}
387
388impl Drop for UsagePlugin {
389 fn drop(&mut self) {
390 self.cancellation_token.cancel();
391 }
393}
394
#[cfg(test)]
mod hive_usage_tests {
    use apollo_router::{
        plugin::{test::MockSupergraphService, Plugin, PluginInit},
        services::supergraph,
    };
    use http::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
    use httpmock::{Method::POST, Mock, MockServer};
    use jsonschema::Validator;
    use serde_json::json;
    use tower::ServiceExt;

    use crate::consts::PLUGIN_VERSION;

    use super::{Config, UsagePlugin};

    lazy_static::lazy_static! {
        // Validator for the usage report (v2) JSON schema; every report sent to
        // the mocked endpoint must satisfy it to count as a hit.
        static ref SCHEMA_VALIDATOR: Validator = jsonschema::validator_for(
            &serde_json::from_str(
                &std::fs::read_to_string("../../services/usage/usage-report-v2.schema.json")
                    .expect("can't load json schema file"),
            )
            .expect("failed to parse json schema"),
        )
        .expect("failed to parse schema");
    }

    /// A plugin instance wired to a local mock of the Hive usage endpoint.
    struct UsageTestHelper {
        mocked_upstream: MockServer,
        plugin: UsagePlugin,
    }

    impl UsageTestHelper {
        /// Starts the mock server and builds a plugin that flushes every report
        /// immediately (buffer size 1, flush interval 1s).
        async fn new() -> Self {
            let server: MockServer = MockServer::start();
            let config = Config {
                enabled: Some(true),
                registry_usage_endpoint: Some(server.url("/usage")),
                registry_token: Some("123".into()),
                buffer_size: Some(1),
                flush_interval: Some(1),
                ..Config::default()
            };

            let plugin = UsagePlugin::new(
                PluginInit::fake_builder()
                    .config(config)
                    .supergraph_sdl("type Query { dummy: String! }".to_string().into())
                    .build(),
            )
            .await
            .expect("failed to init plugin");

            UsageTestHelper {
                mocked_upstream: server,
                plugin,
            }
        }

        /// Gives the background flush task time to send buffered reports.
        fn wait_for_processing(&self) -> tokio::time::Sleep {
            tokio::time::sleep(tokio::time::Duration::from_secs(2))
        }

        /// Registers a mock matching a well-formed, schema-valid usage report POST.
        fn activate_usage_mock(&self) -> Mock<'_> {
            self.mocked_upstream.mock(|when, then| {
                when.method(POST)
                    .path("/usage")
                    .header(CONTENT_TYPE.as_str(), "application/json")
                    .header(
                        USER_AGENT.as_str(),
                        format!("hive-apollo-router/{}", PLUGIN_VERSION),
                    )
                    .header(AUTHORIZATION.as_str(), "Bearer 123")
                    .header("X-Usage-API-Version", "2")
                    .matches(|r| {
                        let body = r.body.as_ref().unwrap();
                        let body = String::from_utf8(body.to_vec()).unwrap();
                        let body = serde_json::from_str(&body).unwrap();

                        SCHEMA_VALIDATOR.is_valid(&body)
                    });
                then.status(200);
            })
        }

        /// Runs `req` through the plugin on top of a mocked supergraph service.
        async fn execute_operation(&self, req: supergraph::Request) -> supergraph::Response {
            let mut supergraph_service_mock = MockSupergraphService::new();

            supergraph_service_mock
                .expect_call()
                .times(1)
                .returning(move |_| {
                    Ok(supergraph::Response::fake_builder()
                        .data(json!({
                            "data": { "hello": "world" },
                        }))
                        .build()
                        .unwrap())
                });

            self.plugin
                .supergraph_service(supergraph_service_mock.boxed())
                .oneshot(req)
                .await
                .expect("failed to execute operation")
        }

        /// Executes `req`, consumes the first response, waits for the flush and
        /// asserts exactly one schema-valid usage report reached the endpoint.
        async fn assert_single_report(&self, req: supergraph::Request) {
            let mock = self.activate_usage_mock();

            self.execute_operation(req).await.next_response().await;
            self.wait_for_processing().await;

            mock.assert_hits(1);
        }
    }

    #[tokio::test]
    async fn should_work_correctly_for_simple_query() {
        let instance = UsageTestHelper::new().await;
        let req = supergraph::Request::fake_builder()
            .query("query test { hello }")
            .operation_name("test")
            .build()
            .unwrap();

        instance.assert_single_report(req).await;
    }

    #[tokio::test]
    async fn without_operation_name() {
        let instance = UsageTestHelper::new().await;
        let req = supergraph::Request::fake_builder()
            .query("query { hello }")
            .build()
            .unwrap();

        instance.assert_single_report(req).await;
    }

    #[tokio::test]
    async fn multiple_operations() {
        let instance = UsageTestHelper::new().await;
        let req = supergraph::Request::fake_builder()
            .query("query test { hello } query test2 { hello }")
            .operation_name("test")
            .build()
            .unwrap();

        instance.assert_single_report(req).await;
    }
}