Skip to main content

hive_apollo_router_plugin/
usage.rs

1use crate::consts::PLUGIN_VERSION;
2use apollo_router::layers::ServiceBuilderExt;
3use apollo_router::plugin::Plugin;
4use apollo_router::plugin::PluginInit;
5use apollo_router::services::*;
6use apollo_router::Context;
7use core::ops::Drop;
8use futures::StreamExt;
9use hive_console_sdk::agent::usage_agent::UsageAgentExt;
10use hive_console_sdk::agent::usage_agent::{ExecutionReport, UsageAgent};
11use hive_console_sdk::graphql_tools::parser::parse_schema;
12use hive_console_sdk::graphql_tools::parser::schema::Document;
13use http::HeaderValue;
14use rand::Rng;
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use std::collections::HashSet;
18use std::env;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use std::time::{SystemTime, UNIX_EPOCH};
22use tokio_util::sync::CancellationToken;
23use tower::BoxError;
24use tower::ServiceBuilder;
25use tower::ServiceExt;
26
27use crate::persisted_documents::PERSISTED_DOCUMENT_HASH_KEY;
28
/// Key under which the per-request `OperationContext` is stored in the router `Context`.
pub(crate) static OPERATION_CONTEXT: &str = "hive::operation_context";
30
31#[derive(Serialize, Deserialize, Debug)]
32struct OperationContext {
33    pub(crate) client_name: Option<String>,
34    pub(crate) client_version: Option<String>,
35    pub(crate) timestamp: u64,
36    pub(crate) operation_body: String,
37    pub(crate) operation_name: Option<String>,
38    pub(crate) dropped: bool,
39}
40
/// Per-request settings, resolved once from the user-facing `Config` when the plugin starts.
#[derive(Debug, Clone)]
struct OperationConfig {
    /// Probability (0.0..=1.0) that an operation is reported.
    sample_rate: f64,
    /// Operation names that are never reported.
    exclude: Option<Vec<String>>,
    /// Name of the HTTP header that carries the client name.
    client_name_header: String,
    /// Name of the HTTP header that carries the client version.
    client_version_header: String,
}
48
/// Apollo Router plugin that reports operation usage to GraphQL Hive.
pub struct UsagePlugin {
    /// Sampling, exclusion and client-header settings applied to every request.
    config: OperationConfig,
    /// Usage agent that buffers and sends reports; `None` when the plugin is disabled.
    agent: Option<UsageAgent>,
    /// Supergraph schema parsed from the SDL at start-up, attached to every report.
    schema: Arc<Document<'static, String>>,
    /// Cancelled when the plugin is dropped; handed to the background flush-interval task.
    cancellation_token: Arc<CancellationToken>,
}
55
56#[derive(Clone, Debug, Deserialize, JsonSchema, Default)]
57pub struct Config {
58    /// Default: true
59    enabled: Option<bool>,
60    /// Hive token, can also be set using the HIVE_TOKEN environment variable.
61    /// The token can be a registry access token, or a organization access token.
62    registry_token: Option<String>,
63    /// Hive registry token. Set to your `/usage` endpoint if you are self-hosting.
64    /// Default: https://app.graphql-hive.com/usage
65    /// When `target` is set and organization access token is in use, the target ID is appended to the endpoint,
66    /// so usage endpoint becomes `https://app.graphql-hive.com/usage/<target_id>`
67    registry_usage_endpoint: Option<String>,
68    /// The target to which the usage data should be reported to.
69    /// This can either be a slug following the format "$organizationSlug/$projectSlug/$targetSlug" (e.g "the-guild/graphql-hive/staging")
70    /// or an UUID (e.g. "a0f4c605-6541-4350-8cfe-b31f21a4bf80").
71    target: Option<String>,
72    /// Sample rate to determine sampling.
73    /// 0.0 = 0% chance of being sent
74    /// 1.0 = 100% chance of being sent.
75    /// Default: 1.0
76    sample_rate: Option<f64>,
77    /// A list of operations (by name) to be ignored by GraphQL Hive.
78    exclude: Option<Vec<String>>,
79    client_name_header: Option<String>,
80    client_version_header: Option<String>,
81    /// A maximum number of operations to hold in a buffer before sending to GraphQL Hive
82    /// Default: 1000
83    buffer_size: Option<usize>,
84    /// A timeout for only the connect phase of a request to GraphQL Hive
85    /// Unit: seconds
86    /// Default: 5 (s)
87    connect_timeout: Option<u64>,
88    /// A timeout for the entire request to GraphQL Hive
89    /// Unit: seconds
90    /// Default: 15 (s)
91    request_timeout: Option<u64>,
92    /// Accept invalid SSL certificates
93    /// Default: false
94    accept_invalid_certs: Option<bool>,
95    /// Frequency of flushing the buffer to the server
96    /// Default: 5 seconds
97    flush_interval: Option<u64>,
98}
99
100impl UsagePlugin {
101    fn populate_context(config: OperationConfig, req: &supergraph::Request) {
102        let context = &req.context;
103        let http_request = &req.supergraph_request;
104        let headers = http_request.headers();
105
106        let get_header_value = |key: &str| {
107            headers
108                .get(key)
109                .cloned()
110                .unwrap_or_else(|| HeaderValue::from_static(""))
111                .to_str()
112                .ok()
113                .map(|v| v.to_string())
114        };
115
116        let client_name = get_header_value(&config.client_name_header);
117        let client_version = get_header_value(&config.client_version_header);
118
119        let operation_name = req.supergraph_request.body().operation_name.clone();
120        let operation_body = req
121            .supergraph_request
122            .body()
123            .query
124            .clone()
125            .expect("operation body should not be empty");
126
127        let excluded_operation_names: HashSet<String> = config
128            .exclude
129            .unwrap_or_default()
130            .clone()
131            .into_iter()
132            .collect();
133
134        let mut rng = rand::rng();
135        let sampled = rng.random::<f64>() < config.sample_rate;
136        let mut dropped = !sampled;
137
138        if !dropped {
139            if let Some(name) = &operation_name {
140                if excluded_operation_names.contains(name) {
141                    dropped = true;
142                }
143            }
144        }
145
146        let _ = context.insert(
147            OPERATION_CONTEXT,
148            OperationContext {
149                dropped,
150                client_name,
151                client_version,
152                operation_name,
153                operation_body,
154                timestamp: SystemTime::now()
155                    .duration_since(UNIX_EPOCH)
156                    .unwrap()
157                    .as_secs()
158                    * 1000,
159            },
160        );
161    }
162}
163
164#[async_trait::async_trait]
165impl Plugin for UsagePlugin {
166    type Config = Config;
167
168    async fn new(init: PluginInit<Config>) -> Result<Self, BoxError> {
169        let user_config = init.config;
170
171        let enabled = user_config.enabled.unwrap_or(true);
172
173        if enabled {
174            tracing::info!("Starting GraphQL Hive Usage plugin");
175        }
176
177        let cancellation_token = Arc::new(CancellationToken::new());
178
179        let agent = if enabled {
180            let mut agent =
181                UsageAgent::builder().user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION));
182
183            if let Some(endpoint) = user_config.registry_usage_endpoint {
184                agent = agent.endpoint(endpoint);
185            } else if let Ok(env_endpoint) = env::var("HIVE_ENDPOINT") {
186                agent = agent.endpoint(env_endpoint);
187            }
188
189            if let Some(token) = user_config.registry_token {
190                agent = agent.token(token);
191            } else if let Ok(env_token) = env::var("HIVE_TOKEN") {
192                agent = agent.token(env_token);
193            }
194
195            if let Some(target_id) = user_config.target {
196                agent = agent.target_id(target_id);
197            } else if let Ok(env_target) = env::var("HIVE_TARGET_ID") {
198                agent = agent.target_id(env_target);
199            }
200
201            if let Some(buffer_size) = user_config.buffer_size {
202                agent = agent.buffer_size(buffer_size);
203            }
204
205            if let Some(connect_timeout) = user_config.connect_timeout {
206                agent = agent.connect_timeout(Duration::from_secs(connect_timeout));
207            }
208
209            if let Some(request_timeout) = user_config.request_timeout {
210                agent = agent.request_timeout(Duration::from_secs(request_timeout));
211            }
212
213            if let Some(accept_invalid_certs) = user_config.accept_invalid_certs {
214                agent = agent.accept_invalid_certs(accept_invalid_certs);
215            }
216
217            if let Some(flush_interval) = user_config.flush_interval {
218                agent = agent.flush_interval(Duration::from_secs(flush_interval));
219            }
220
221            let agent = agent.build().map_err(Box::new)?;
222
223            let cancellation_token_for_interval = cancellation_token.clone();
224            let agent_for_interval = agent.clone();
225            tokio::task::spawn(async move {
226                agent_for_interval
227                    .start_flush_interval(&cancellation_token_for_interval)
228                    .await;
229            });
230            Some(agent)
231        } else {
232            None
233        };
234
235        let schema = parse_schema(&init.supergraph_sdl)
236            .expect("Failed to parse schema")
237            .into_static();
238
239        Ok(UsagePlugin {
240            schema: Arc::new(schema),
241            config: OperationConfig {
242                sample_rate: user_config.sample_rate.unwrap_or(1.0),
243                exclude: user_config.exclude,
244                client_name_header: user_config
245                    .client_name_header
246                    .unwrap_or("graphql-client-name".to_string()),
247                client_version_header: user_config
248                    .client_version_header
249                    .unwrap_or("graphql-client-version".to_string()),
250            },
251            agent,
252            cancellation_token,
253        })
254    }
255
256    fn supergraph_service(&self, service: supergraph::BoxService) -> supergraph::BoxService {
257        let config = self.config.clone();
258        let schema = self.schema.clone();
259        match self.agent.clone() {
260            None => ServiceBuilder::new().service(service).boxed(),
261            Some(agent) => {
262                ServiceBuilder::new()
263                    .map_future_with_request_data(
264                        move |req: &supergraph::Request| {
265                            Self::populate_context(config.clone(), req);
266                            req.context.clone()
267                        },
268                        move |ctx: Context, fut| {
269                            let agent = agent.clone();
270                            let schema = schema.clone();
271                            async move {
272                                let start: Instant = Instant::now();
273
274                                // nested async block, bc async is unstable with closures that receive arguments
275                                let operation_context = ctx
276                                    .get::<_, OperationContext>(OPERATION_CONTEXT)
277                                    .unwrap_or_default()
278                                    .unwrap();
279
280                                // Injected by the persisted document plugin, if it was activated
281                                // and discovered document id
282                                let persisted_document_hash = ctx
283                                    .get::<_, String>(PERSISTED_DOCUMENT_HASH_KEY)
284                                    .ok()
285                                    .unwrap();
286
287                                let result: supergraph::ServiceResult = fut.await;
288
289                                if operation_context.dropped {
290                                    tracing::debug!(
291                                        "Dropping operation (phase: SAMPLING): {}",
292                                        operation_context
293                                            .operation_name
294                                            .clone()
295                                            .or_else(|| Some("anonymous".to_string()))
296                                            .unwrap()
297                                    );
298                                    return result;
299                                }
300
301                                let OperationContext {
302                                    client_name,
303                                    client_version,
304                                    operation_name,
305                                    timestamp,
306                                    operation_body,
307                                    ..
308                                } = operation_context;
309
310                                let duration = start.elapsed();
311
312                                match result {
313                                    Err(e) => {
314                                        tokio::spawn(async move {
315                                            let res = agent
316                                                .add_report(ExecutionReport {
317                                                    schema,
318                                                    client_name,
319                                                    client_version,
320                                                    timestamp,
321                                                    duration,
322                                                    ok: false,
323                                                    errors: 1,
324                                                    operation_body,
325                                                    operation_name,
326                                                    persisted_document_hash,
327                                                })
328                                                .await;
329                                            if let Err(e) = res {
330                                                tracing::error!("Error adding report: {}", e);
331                                            }
332                                        });
333                                        Err(e)
334                                    }
335                                    Ok(router_response) => {
336                                        let is_failure =
337                                            !router_response.response.status().is_success();
338                                        Ok(router_response.map(move |response_stream| {
339                                            let res = response_stream
340                                                .map(move |response| {
341                                                    // make sure we send a single report, not for each chunk
342                                                    let response_has_errors =
343                                                        !response.errors.is_empty();
344                                                    let agent = agent.clone();
345                                                    let execution_report = ExecutionReport {
346                                                        schema: schema.clone(),
347                                                        client_name: client_name.clone(),
348                                                        client_version: client_version.clone(),
349                                                        timestamp,
350                                                        duration,
351                                                        ok: !is_failure && !response_has_errors,
352                                                        errors: response.errors.len(),
353                                                        operation_body: operation_body.clone(),
354                                                        operation_name: operation_name.clone(),
355                                                        persisted_document_hash:
356                                                            persisted_document_hash.clone(),
357                                                    };
358                                                    tokio::spawn(async move {
359                                                        let res = agent
360                                                            .add_report(execution_report)
361                                                            .await;
362                                                        if let Err(e) = res {
363                                                            tracing::error!(
364                                                                "Error adding report: {}",
365                                                                e
366                                                            );
367                                                        }
368                                                    });
369
370                                                    response
371                                                })
372                                                .boxed();
373
374                                            res
375                                        }))
376                                    }
377                                }
378                            }
379                        },
380                    )
381                    .service(service)
382                    .boxed()
383            }
384        }
385    }
386}
387
388impl Drop for UsagePlugin {
389    fn drop(&mut self) {
390        self.cancellation_token.cancel();
391        // Flush already done by UsageAgent's Drop impl
392    }
393}
394
#[cfg(test)]
mod hive_usage_tests {
    use apollo_router::{
        plugin::{test::MockSupergraphService, Plugin, PluginInit},
        services::supergraph,
    };
    use http::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
    use httpmock::{Method::POST, Mock, MockServer};
    use jsonschema::Validator;
    use serde_json::json;
    use tower::ServiceExt;

    use crate::consts::PLUGIN_VERSION;

    use super::{Config, UsagePlugin};

    lazy_static::lazy_static! {
        /// Validator built from the usage-report v2 JSON schema shipped with the usage service.
        static ref SCHEMA_VALIDATOR: Validator = {
            let raw_schema =
                std::fs::read_to_string("../../services/usage/usage-report-v2.schema.json")
                    .expect("can't load json schema file");
            let parsed_schema: serde_json::Value =
                serde_json::from_str(&raw_schema).expect("failed to parse json schema");
            jsonschema::validator_for(&parsed_schema).expect("failed to parse schema")
        };
    }

    /// A plugin instance wired to a local mock of the Hive usage endpoint.
    struct UsageTestHelper {
        mocked_upstream: MockServer,
        plugin: UsagePlugin,
    }

    impl UsageTestHelper {
        /// Starts the mock server and a plugin that reports to it with a buffer of one
        /// operation and a one-second flush interval.
        async fn new() -> Self {
            let mocked_upstream: MockServer = MockServer::start();

            let config = Config {
                enabled: Some(true),
                registry_usage_endpoint: Some(mocked_upstream.url("/usage")),
                registry_token: Some("123".into()),
                buffer_size: Some(1),
                flush_interval: Some(1),
                ..Config::default()
            };

            let init = PluginInit::fake_builder()
                .config(config)
                .supergraph_sdl("type Query { dummy: String! }".to_string().into())
                .build();
            let plugin = UsagePlugin::new(init)
                .await
                .expect("failed to init plugin");

            UsageTestHelper {
                mocked_upstream,
                plugin,
            }
        }

        /// Gives the spawned report task and the flush interval time to run.
        fn wait_for_processing(&self) -> tokio::time::Sleep {
            tokio::time::sleep(tokio::time::Duration::from_secs(2))
        }

        /// Registers the expected usage request on the mock server.
        fn activate_usage_mock(&'_ self) -> Mock<'_> {
            self.mocked_upstream.mock(|when, then| {
                when.method(POST)
                    .path("/usage")
                    .header(CONTENT_TYPE.as_str(), "application/json")
                    .header(
                        USER_AGENT.as_str(),
                        format!("hive-apollo-router/{}", PLUGIN_VERSION),
                    )
                    .header(AUTHORIZATION.as_str(), "Bearer 123")
                    .header("X-Usage-API-Version", "2")
                    .matches(|r| {
                        // The reported payload must also satisfy the usage-report JSON schema.
                        // A payload that does not validate makes the request not match, which
                        // in turn fails the hit-count assertion of the test.
                        let raw_body = r.body.as_ref().unwrap();
                        let text_body = String::from_utf8(raw_body.to_vec()).unwrap();
                        let json_body = serde_json::from_str(&text_body).unwrap();

                        SCHEMA_VALIDATOR.is_valid(&json_body)
                    });
                then.status(200);
            })
        }

        /// Runs `req` through the plugin's supergraph layer on top of a mocked inner service
        /// that is expected to be called exactly once.
        async fn execute_operation(&self, req: supergraph::Request) -> supergraph::Response {
            let mut inner_service = MockSupergraphService::new();

            inner_service.expect_call().times(1).returning(move |_| {
                Ok(supergraph::Response::fake_builder()
                    .data(json!({
                        "data": { "hello": "world" },
                    }))
                    .build()
                    .unwrap())
            });

            self.plugin
                .supergraph_service(inner_service.boxed())
                .oneshot(req)
                .await
                .expect("failed to execute operation")
        }
    }

    #[tokio::test]
    async fn should_work_correctly_for_simple_query() {
        let helper = UsageTestHelper::new().await;
        let request = supergraph::Request::fake_builder()
            .query("query test { hello }")
            .operation_name("test")
            .build()
            .unwrap();
        let usage_mock = helper.activate_usage_mock();

        helper.execute_operation(request).await.next_response().await;
        helper.wait_for_processing().await;

        usage_mock.assert();
        usage_mock.assert_hits(1);
    }

    #[tokio::test]
    async fn without_operation_name() {
        let helper = UsageTestHelper::new().await;
        let request = supergraph::Request::fake_builder()
            .query("query { hello }")
            .build()
            .unwrap();
        let usage_mock = helper.activate_usage_mock();

        helper.execute_operation(request).await.next_response().await;
        helper.wait_for_processing().await;

        usage_mock.assert();
        usage_mock.assert_hits(1);
    }

    #[tokio::test]
    async fn multiple_operations() {
        let helper = UsageTestHelper::new().await;
        let request = supergraph::Request::fake_builder()
            .query("query test { hello } query test2 { hello }")
            .operation_name("test")
            .build()
            .unwrap();
        let usage_mock = helper.activate_usage_mock();

        helper.execute_operation(request).await.next_response().await;
        helper.wait_for_processing().await;
        println!("Waiting done");

        usage_mock.assert();
        usage_mock.assert_hits(1);
    }
}