// cloud_pubsub/client.rs

1use crate::error;
2use crate::subscription::Subscription;
3use crate::topic::Topic;
4use goauth::auth::JwtClaims;
5use goauth::scopes::Scope;
6use hyper::client::HttpConnector;
7use hyper_tls::HttpsConnector;
8use smpl_jwt::Jwt;
9use std::fs;
10use std::str::FromStr;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, RwLock};
13use std::time::Duration;
14use tokio::task;
15use tokio::time;
16
17type HyperClient = Arc<hyper::Client<HttpsConnector<HttpConnector>, hyper::Body>>;
18
19pub struct State {
20    token: Option<goauth::auth::Token>,
21    credentials_string: String,
22    project: Option<String>,
23    hyper_client: HyperClient,
24    running: Arc<AtomicBool>,
25}
26
27impl State {
28    pub fn token_type(&self) -> &str {
29        self.token.as_ref().unwrap().token_type()
30    }
31
32    pub fn access_token(&self) -> &str {
33        self.token.as_ref().unwrap().access_token()
34    }
35
36    pub fn project(&self) -> &str {
37        &(self.project.as_ref().expect("Google Cloud Project has not been set. If it is not in your credential file, call set_project to set it manually."))
38    }
39}
40
41pub struct Client(Arc<RwLock<State>>);
42
43impl Clone for Client {
44    fn clone(&self) -> Self {
45        Client(self.0.clone())
46    }
47}
48
49impl Client {
50    pub async fn from_string(credentials_string: String) -> Result<Self, error::Error> {
51        let mut client = Client(Arc::new(RwLock::new(State {
52            token: None,
53            credentials_string,
54            project: None,
55            hyper_client: setup_hyper(),
56            running: Arc::new(AtomicBool::new(true)),
57        })));
58
59        match client.refresh_token().await {
60            Ok(_) => Ok(client),
61            Err(e) => Err(e),
62        }
63    }
64
65    pub async fn new(credentials_path: String) -> Result<Self, error::Error> {
66        let credentials_string = fs::read_to_string(credentials_path).unwrap();
67        Self::from_string(credentials_string).await
68    }
69
70    pub fn subscribe(&self, name: String) -> Subscription {
71        Subscription {
72            client: Some(self.clone()),
73            name: format!("projects/{}/subscriptions/{}", self.project(), name),
74            topic: None,
75        }
76    }
77
78    pub fn set_project(&mut self, project: String) {
79        self.0.write().unwrap().project = Some(project);
80    }
81
82    pub fn project(&self) -> String {
83        self.0.read().unwrap().project().to_string()
84    }
85
86    pub fn topic(&self, name: String) -> Topic {
87        Topic {
88            client: Some(Client(self.0.clone())),
89            name: format!("projects/{}/topics/{}", self.project(), name),
90        }
91    }
92
93    pub fn is_running(&self) -> bool {
94        self.0.read().unwrap().running.load(Ordering::SeqCst)
95    }
96
97    pub fn stop(&self) {
98        self.0
99            .write()
100            .unwrap()
101            .running
102            .store(false, Ordering::SeqCst)
103    }
104
105    pub fn spawn_token_renew(&self, interval: Duration) {
106        let mut client = self.clone();
107        let c = self.clone();
108        let renew_token_task = async move {
109            let mut int = time::interval(interval);
110            loop {
111                if c.is_running() {
112                    int.tick().await;
113                    log::debug!("Renewing pubsub token");
114                    if let Err(e) = client.refresh_token().await {
115                        log::error!("Failed to update token: {}", e);
116                    }
117                }
118            }
119        };
120
121        task::spawn(renew_token_task);
122    }
123
124    pub async fn refresh_token(&mut self) -> Result<(), error::Error> {
125        match self.get_token().await {
126            Ok(token) => {
127                self.0.write().unwrap().token = Some(token);
128                Ok(())
129            }
130            Err(e) => Err(error::Error::from(e)),
131        }
132    }
133
134    async fn get_token(&mut self) -> Result<goauth::auth::Token, goauth::GoErr> {
135        let credentials =
136            goauth::credentials::Credentials::from_str(&self.0.read().unwrap().credentials_string)
137                .unwrap();
138
139        self.set_project(credentials.project());
140
141        let claims = JwtClaims::new(
142            credentials.iss(),
143            &Scope::PubSub,
144            credentials.token_uri(),
145            None,
146            None,
147        );
148        let jwt = Jwt::new(claims, credentials.rsa_key().unwrap(), None);
149        goauth::get_token(&jwt, &credentials).await
150    }
151
152    pub(crate) fn request<T: Into<hyper::Body>>(
153        &self,
154        method: hyper::Method,
155        data: T,
156    ) -> hyper::Request<hyper::Body>
157    where
158        hyper::Body: std::convert::From<T>,
159    {
160        let mut req = hyper::Request::new(hyper::Body::from(data));
161        *req.method_mut() = method;
162        req.headers_mut().insert(
163            hyper::header::CONTENT_TYPE,
164            hyper::header::HeaderValue::from_static("application/json"),
165        );
166        let readable = self.0.read().unwrap();
167        req.headers_mut().insert(
168            hyper::header::AUTHORIZATION,
169            hyper::header::HeaderValue::from_str(&format!(
170                "{} {}",
171                readable.token_type(),
172                readable.access_token()
173            ))
174            .unwrap(),
175        );
176        req
177    }
178
179    pub fn hyper_client(&self) -> HyperClient {
180        self.0.read().unwrap().hyper_client.clone()
181    }
182}
183
184fn setup_hyper() -> HyperClient {
185    let https = HttpsConnector::new();
186    Arc::new(hyper::Client::builder().build::<_, hyper::Body>(https))
187}