m10_sdk/
oauth_interceptor.rs

1use reqwest::header::HeaderMap;
2use reqwest::header::HeaderValue;
3use reqwest::Client;
4use serde::Deserialize;
5use std::fs;
6use std::sync::Arc;
7use std::sync::Mutex;
8use toml;
9use tonic::{service::Interceptor, Request, Status};
10
/// gRPC interceptor that attaches an OAuth bearer token to outgoing requests.
///
/// Cloning is cheap: every clone shares the same token slot through the `Arc`, so a token
/// stored through one clone is observed by all of them.
#[derive(Clone)]
pub struct OauthInterceptor {
    /// Current access token; `None` until one is stored, or after it has been cleared.
    access_token: Arc<Mutex<Option<String>>>,
}
15
16#[derive(Deserialize, Clone)]
17pub struct Config {
18    oauth: OauthConfig,
19}
20
21#[derive(Deserialize, Clone)]
22pub struct OauthConfig {
23    client_id: String,
24    client_secret: String,
25    base_url: String,
26}
27
28#[derive(Deserialize)]
29struct TokenResponse {
30    access_token: String,
31}
32
33fn load_config(filepath: &str) -> Result<Config, std::io::Error> {
34    let contents = fs::read_to_string(filepath)?;
35    toml::from_str(&contents).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
36}
37
38// TODO: Add handling for expired tokens
39/// Handles fetching and adding access tokens to requests. If config is not present,
40/// or any errors occur with the interceptor, this will pass any gRPC requests without
41/// modification.
42impl OauthInterceptor {
43    pub fn new() -> Self {
44        let interceptor = Self {
45            access_token: Arc::new(Mutex::new(None)),
46        };
47
48        let interceptor_clone = interceptor.clone();
49        tokio::spawn(async move {
50            if let Ok(cfg) = load_config("config.toml") {
51                let _ = interceptor_clone
52                    .set_token_from_http(
53                        format!("{}/accesstoken", cfg.oauth.base_url,).as_str(),
54                        &cfg.oauth.client_id,
55                        &cfg.oauth.client_secret,
56                    )
57                    .await;
58            }
59        });
60
61        interceptor
62    }
63
64    pub fn set_token(&self, token: &str) {
65        let mut token_guard = self.access_token.lock().unwrap();
66        *token_guard = Some(token.into());
67    }
68
69    pub fn clear_token(&self) {
70        let mut token_guard = self.access_token.lock().unwrap();
71        *token_guard = None;
72    }
73
74    fn get_token(&self) -> Option<String> {
75        self.access_token.lock().unwrap().clone()
76    }
77
78    pub async fn set_token_from_http(
79        &self,
80        url: &str,
81        client_id: &str,
82        client_secret: &str,
83    ) -> Result<(), reqwest::Error> {
84        let client = Client::new();
85
86        let mut headers = HeaderMap::new();
87        headers.insert(
88            "X-SunGard-IdP-API-Key",
89            HeaderValue::from_str("SunGard-IdP-UI").unwrap(),
90        );
91        headers.insert("Accept", HeaderValue::from_str("application/json").unwrap());
92        headers.insert(
93            "Content-Type",
94            HeaderValue::from_str("application/x-www-form-urlencoded").unwrap(),
95        );
96
97        let encoded_params = serde_urlencoded::to_string(&[
98            ("client_id", client_id),
99            ("client_secret", client_secret),
100        ])
101        .unwrap();
102
103        let response = client
104            .post(url)
105            .headers(headers)
106            .body(encoded_params)
107            .send()
108            .await?
109            .error_for_status()?;
110
111        let token_response: TokenResponse = response.json().await?;
112        self.set_token(&token_response.access_token);
113        Ok(())
114    }
115}
116
117impl Interceptor for OauthInterceptor {
118    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
119        if let Some(token) = self.get_token() {
120            request.metadata_mut().insert(
121                "authorization",
122                format!("Bearer {}", token).parse().unwrap(),
123            );
124            Ok(request)
125        } else {
126            // pass-through if access token is not initialized
127            Ok(request)
128        }
129    }
130}