1use oauth2::basic::BasicClient;
25use oauth2::http::{HeaderMap, HeaderValue, Method};
26use oauth2::{
27 AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
28 PkceCodeVerifier, RedirectUrl, TokenResponse, TokenUrl,
29};
30use serde::{Deserialize, Serialize};
31use std::io::{Read, Write};
32use std::net::{TcpListener, TcpStream};
33use std::time::{Duration, SystemTime};
34use url::Url;
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TokenDetails {
45 pub api_token: String,
47 pub expires_at: SystemTime,
49}
50
51#[derive(Debug, Clone)]
53pub struct AuthorizationInfo {
54 pub url: Url,
56 pub pkce_verifier: PkceVerifier,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct PkceVerifier(String);
63
64impl PkceVerifier {
65 pub fn new(secret: String) -> Self {
67 Self(secret)
68 }
69}
70
71pub struct Authenticator {
75 client_id: ClientId,
76 client_secret: ClientSecret,
77 port: u16,
78}
79
80impl Authenticator {
81 pub fn new(client_id: String, client_secret: String) -> Self {
88 Self {
89 client_id: ClientId::new(client_id),
90 client_secret: ClientSecret::new(client_secret),
91 port: 8080,
92 }
93 }
94
95 pub fn with_redirect_server_port(mut self, port: u16) -> Self {
99 self.port = port;
100 self
101 }
102
103 pub fn authorization_url(&self) -> Result<AuthorizationInfo, Box<dyn std::error::Error>> {
109 let redirect_url = format!("http://localhost:{}", self.port);
110 let client = self.client(&redirect_url)?;
111
112 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
114
115 let (auth_url, _csrf_token) = client
117 .authorize_url(CsrfToken::new_random)
118 .set_pkce_challenge(pkce_challenge)
119 .url();
120
121 Ok(AuthorizationInfo {
122 url: auth_url,
123 pkce_verifier: PkceVerifier::new(pkce_verifier.secret().to_string()),
124 })
125 }
126
127 pub async fn exchange_code(
133 &self,
134 code: AuthorizationCode,
135 pkce_verifier: PkceVerifier,
136 ) -> Result<TokenDetails, Box<dyn std::error::Error>> {
137 let redirect_url = format!("http://localhost:{}", self.port);
138 let client = self.client(&redirect_url)?;
139 let pkce_verifier = PkceCodeVerifier::new(pkce_verifier.0);
140
141 let token_response = client
143 .exchange_code(code)
144 .set_pkce_verifier(pkce_verifier)
145 .request_async(oauth2::reqwest::async_http_client)
146 .await?;
147
148 let token_string = token_response.access_token().secret();
149 log::info!("OAuth access token: {token_string}");
150
151 let response = self.fetch_api_token(token_string).await?;
153 log::info!("OAuth API token: {}", response.api_token);
154
155 let expires_at = SystemTime::now() + Duration::from_secs(24 * 60 * 60);
157 Ok(TokenDetails {
158 api_token: response.api_token,
159 expires_at,
160 })
161 }
162
163 pub fn listen_for_redirect(&self) -> Result<AuthorizationCode, Box<dyn std::error::Error>> {
170 let listener = TcpListener::bind(("127.0.0.1", self.port))?;
172
173 self.listen_for_code(listener)
175 }
176
177 pub async fn get_api_token(self) -> Result<TokenDetails, Box<dyn std::error::Error>> {
183 let auth_info = self.authorization_url()?;
185
186 log::info!("Opening browser to: {}", auth_info.url);
188 opener::open(auth_info.url.to_string())?;
189
190 let code = self.listen_for_redirect()?;
192
193 self.exchange_code(code, auth_info.pkce_verifier).await
195 }
196
197 fn listen_for_code(
199 &self,
200 listener: TcpListener,
201 ) -> Result<AuthorizationCode, Box<dyn std::error::Error>> {
202 for stream in listener.incoming() {
203 match stream {
204 Ok(stream) => {
205 if let Ok(code) = self.handle_connection(stream) {
207 return Ok(code);
208 }
209 }
213 Err(e) => {
214 log::error!("Failed to accept connection: {e}");
215 }
216 }
217 }
218 Err("Server closed before receiving authorization code".into())
219 }
220
221 fn handle_connection(
223 &self,
224 mut stream: TcpStream,
225 ) -> Result<AuthorizationCode, Box<dyn std::error::Error>> {
226 let mut buffer = [0; 1024];
227 stream.read(&mut buffer)?;
228
229 match self.parse_code_from_request(&buffer) {
230 Ok(code) => {
231 let message = "<h1>Success!</h1><p>You can close this window now.</p>";
232 self.send_response(&mut stream, "200 OK", message)?;
233 Ok(code)
234 }
235 Err(e) => {
236 log::error!("Failed to parse code from request: {e}");
237 let message = "<h1>Error!</h1><p>Could not get authorization code from iNaturalist. Please try again.</p>";
238 self.send_response(&mut stream, "400 Bad Request", message)?;
239 Err(e)
240 }
241 }
242 }
243
244 fn parse_code_from_request(
246 &self,
247 buffer: &[u8],
248 ) -> Result<AuthorizationCode, Box<dyn std::error::Error>> {
249 let mut headers = [httparse::EMPTY_HEADER; 64];
250 let mut req = httparse::Request::new(&mut headers);
251 req.parse(buffer)?;
252
253 let path = req.path.ok_or("Malformed request: no path")?;
254 let url = Url::parse(&format!("http://localhost{path}"))?;
255
256 let code_pair = url
257 .query_pairs()
258 .find(|pair| pair.0 == "code")
259 .ok_or_else(|| format!("URL did not contain 'code' parameter: {url}"))?;
260
261 Ok(AuthorizationCode::new(code_pair.1.into_owned()))
262 }
263
264 fn send_response(
266 &self,
267 stream: &mut TcpStream,
268 status: &str,
269 body: &str,
270 ) -> std::io::Result<()> {
271 let response = format!(
272 "HTTP/1.1 {}\r\ncontent-length: {}\r\n\r\n{}",
273 status,
274 body.len(),
275 body
276 );
277 stream.write_all(response.as_bytes())
278 }
279
280 async fn fetch_api_token(
282 &self,
283 token_string: &str,
284 ) -> Result<ApiTokenResponse, Box<dyn std::error::Error>> {
285 let mut headers = HeaderMap::new();
286 headers.append(
287 "Authorization",
288 HeaderValue::from_str(&format!("Bearer {token_string}"))?,
289 );
290 let request = oauth2::HttpRequest {
291 body: vec![],
292 headers,
293 url: "https://www.inaturalist.org/users/api_token".try_into()?,
294 method: Method::GET,
295 };
296
297 let response = oauth2::reqwest::async_http_client(request).await?;
298 Ok(serde_json::from_slice(&response.body)?)
299 }
300
301 fn client(&self, redirect_url: &str) -> Result<BasicClient, Box<dyn std::error::Error>> {
303 let auth_url = AuthUrl::new("https://www.inaturalist.org/oauth/authorize".to_string())?;
304 let token_url = TokenUrl::new("https://www.inaturalist.org/oauth/token".to_string())?;
305
306 Ok(BasicClient::new(
307 self.client_id.clone(),
308 Some(self.client_secret.clone()),
309 auth_url,
310 Some(token_url),
311 )
312 .set_redirect_uri(RedirectUrl::new(redirect_url.to_string())?))
313 }
314}
315
316#[derive(Deserialize)]
317struct ApiTokenResponse {
318 api_token: String,
319}