1use crate::error;
2use crate::subscription::Subscription;
3use crate::topic::Topic;
4use goauth::auth::JwtClaims;
5use goauth::scopes::Scope;
6use hyper::client::HttpConnector;
7use hyper_tls::HttpsConnector;
8use smpl_jwt::Jwt;
9use std::fs;
10use std::str::FromStr;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, RwLock};
13use std::time::Duration;
14use tokio::task;
15use tokio::time;
16
/// hyper client over an HTTPS-capable connector, wrapped in an `Arc` so handing it out
/// (see `Client::hyper_client`) is a reference-count bump rather than a new client.
type HyperClient = Arc<hyper::Client<HttpsConnector<HttpConnector>, hyper::Body>>;
18
19pub struct State {
20 token: Option<goauth::auth::Token>,
21 credentials_string: String,
22 project: Option<String>,
23 hyper_client: HyperClient,
24 running: Arc<AtomicBool>,
25}
26
27impl State {
28 pub fn token_type(&self) -> &str {
29 self.token.as_ref().unwrap().token_type()
30 }
31
32 pub fn access_token(&self) -> &str {
33 self.token.as_ref().unwrap().access_token()
34 }
35
36 pub fn project(&self) -> &str {
37 &(self.project.as_ref().expect("Google Cloud Project has not been set. If it is not in your credential file, call set_project to set it manually."))
38 }
39}
40
41pub struct Client(Arc<RwLock<State>>);
42
43impl Clone for Client {
44 fn clone(&self) -> Self {
45 Client(self.0.clone())
46 }
47}
48
49impl Client {
50 pub async fn from_string(credentials_string: String) -> Result<Self, error::Error> {
51 let mut client = Client(Arc::new(RwLock::new(State {
52 token: None,
53 credentials_string,
54 project: None,
55 hyper_client: setup_hyper(),
56 running: Arc::new(AtomicBool::new(true)),
57 })));
58
59 match client.refresh_token().await {
60 Ok(_) => Ok(client),
61 Err(e) => Err(e),
62 }
63 }
64
65 pub async fn new(credentials_path: String) -> Result<Self, error::Error> {
66 let credentials_string = fs::read_to_string(credentials_path).unwrap();
67 Self::from_string(credentials_string).await
68 }
69
70 pub fn subscribe(&self, name: String) -> Subscription {
71 Subscription {
72 client: Some(self.clone()),
73 name: format!("projects/{}/subscriptions/{}", self.project(), name),
74 topic: None,
75 }
76 }
77
78 pub fn set_project(&mut self, project: String) {
79 self.0.write().unwrap().project = Some(project);
80 }
81
82 pub fn project(&self) -> String {
83 self.0.read().unwrap().project().to_string()
84 }
85
86 pub fn topic(&self, name: String) -> Topic {
87 Topic {
88 client: Some(Client(self.0.clone())),
89 name: format!("projects/{}/topics/{}", self.project(), name),
90 }
91 }
92
93 pub fn is_running(&self) -> bool {
94 self.0.read().unwrap().running.load(Ordering::SeqCst)
95 }
96
97 pub fn stop(&self) {
98 self.0
99 .write()
100 .unwrap()
101 .running
102 .store(false, Ordering::SeqCst)
103 }
104
105 pub fn spawn_token_renew(&self, interval: Duration) {
106 let mut client = self.clone();
107 let c = self.clone();
108 let renew_token_task = async move {
109 let mut int = time::interval(interval);
110 loop {
111 if c.is_running() {
112 int.tick().await;
113 log::debug!("Renewing pubsub token");
114 if let Err(e) = client.refresh_token().await {
115 log::error!("Failed to update token: {}", e);
116 }
117 }
118 }
119 };
120
121 task::spawn(renew_token_task);
122 }
123
124 pub async fn refresh_token(&mut self) -> Result<(), error::Error> {
125 match self.get_token().await {
126 Ok(token) => {
127 self.0.write().unwrap().token = Some(token);
128 Ok(())
129 }
130 Err(e) => Err(error::Error::from(e)),
131 }
132 }
133
134 async fn get_token(&mut self) -> Result<goauth::auth::Token, goauth::GoErr> {
135 let credentials =
136 goauth::credentials::Credentials::from_str(&self.0.read().unwrap().credentials_string)
137 .unwrap();
138
139 self.set_project(credentials.project());
140
141 let claims = JwtClaims::new(
142 credentials.iss(),
143 &Scope::PubSub,
144 credentials.token_uri(),
145 None,
146 None,
147 );
148 let jwt = Jwt::new(claims, credentials.rsa_key().unwrap(), None);
149 goauth::get_token(&jwt, &credentials).await
150 }
151
152 pub(crate) fn request<T: Into<hyper::Body>>(
153 &self,
154 method: hyper::Method,
155 data: T,
156 ) -> hyper::Request<hyper::Body>
157 where
158 hyper::Body: std::convert::From<T>,
159 {
160 let mut req = hyper::Request::new(hyper::Body::from(data));
161 *req.method_mut() = method;
162 req.headers_mut().insert(
163 hyper::header::CONTENT_TYPE,
164 hyper::header::HeaderValue::from_static("application/json"),
165 );
166 let readable = self.0.read().unwrap();
167 req.headers_mut().insert(
168 hyper::header::AUTHORIZATION,
169 hyper::header::HeaderValue::from_str(&format!(
170 "{} {}",
171 readable.token_type(),
172 readable.access_token()
173 ))
174 .unwrap(),
175 );
176 req
177 }
178
179 pub fn hyper_client(&self) -> HyperClient {
180 self.0.read().unwrap().hyper_client.clone()
181 }
182}
183
184fn setup_hyper() -> HyperClient {
185 let https = HttpsConnector::new();
186 Arc::new(hyper::Client::builder().build::<_, hyper::Body>(https))
187}