smbpndk_cli/account/
lib.rs

1use super::{model::User, signup::GithubEmail};
2use anyhow::{anyhow, Result};
3use console::style;
4use log::debug;
5use regex::Regex;
6use reqwest::{Client, Response, StatusCode};
7use serde::{Deserialize, Serialize};
8use serde_repr::Deserialize_repr;
9use smbpndk_networking::{
10    constants::{GH_OAUTH_CLIENT_ID, GH_OAUTH_REDIRECT_HOST, GH_OAUTH_REDIRECT_PORT},
11    smb_base_url_builder,
12};
13use spinners::Spinner;
14use std::{
15    fmt::{Display, Formatter},
16    fs::{create_dir_all, OpenOptions},
17    io::{BufRead, BufReader, Write},
18    net::{TcpListener, TcpStream},
19    sync::mpsc::{self, Receiver, Sender},
20};
21use url_builder::URLBuilder;
22
23// This is smb authorization model.
24#[derive(Debug, Serialize, Deserialize)]
25pub struct SmbAuthorization {
26    pub message: String,
27    pub user: Option<User>,
28    pub user_email: Option<GithubEmail>,
29    pub user_info: Option<GithubInfo>,
30    pub error_code: Option<ErrorCode>,
31}
32
33#[derive(Debug, serde_repr::Serialize_repr, Deserialize_repr, PartialEq)]
34#[repr(u32)]
35pub enum ErrorCode {
36    EmailNotFound = 1000,
37    EmailUnverified = 1001,
38    PasswordNotSet = 1003,
39    GithubNotLinked = 1004,
40}
41
42impl Display for ErrorCode {
43    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44        match self {
45            ErrorCode::EmailNotFound => write!(f, "Email not found."),
46            ErrorCode::EmailUnverified => write!(f, "Email not verified."),
47            ErrorCode::PasswordNotSet => write!(f, "Password not set."),
48            ErrorCode::GithubNotLinked => write!(f, "Github not connected."),
49        }
50    }
51}
52
53impl Copy for ErrorCode {}
54
55impl Clone for ErrorCode {
56    fn clone(&self) -> Self {
57        *self
58    }
59}
60
61#[derive(Debug, Serialize, Deserialize)]
62pub struct GithubInfo {
63    pub id: i64,
64    pub login: String,
65    pub name: String,
66    pub avatar_url: String,
67    pub html_url: String,
68    pub email: Option<String>,
69    pub created_at: String,
70    pub updated_at: String,
71}
72
73pub async fn authorize_github() -> Result<SmbAuthorization> {
74    // Spin up a simple localhost server to listen for the GitHub OAuth callback
75    // setup_oauth_callback_server();
76    // Open the GitHub OAuth URL in the user's browser
77    let mut spinner = Spinner::new(
78        spinners::Spinners::BouncingBall,
79        style("🚀 Getting your GitHub information...")
80            .green()
81            .bold()
82            .to_string(),
83    );
84
85    let rx = match open::that(build_github_oauth_url()) {
86        Ok(_) => {
87            let (tx, rx): (Sender<String>, Receiver<String>) = mpsc::channel();
88            debug!(
89                "Setting up OAuth callback server... (tx: {:#?}, rx: {:#?})",
90                &tx, &rx
91            );
92            tokio::spawn(async move {
93                setup_oauth_callback_server(tx);
94            });
95            rx
96        }
97        Err(_) => {
98            let error = anyhow!("Failed to open a browser.");
99            return Err(error);
100        }
101    };
102
103    spinner.stop_and_persist("⌛", "Waiting for the authorization.".into());
104
105    debug!("Waiting for code from channel...");
106
107    match rx.recv() {
108        Ok(code) => {
109            debug!("Got code from channel: {:#?}", &code);
110            //Err(anyhow!("Failed to get code from channel."))
111            process_connect_github(code).await
112        }
113        Err(e) => {
114            let error = anyhow!("Failed to get code from channel: {e}");
115            Err(error)
116        }
117    }
118}
119
120fn setup_oauth_callback_server(tx: Sender<String>) {
121    let listener = TcpListener::bind(format!("127.0.0.1:{}", GH_OAUTH_REDIRECT_PORT)).unwrap();
122    for stream in listener.incoming() {
123        let stream = stream.unwrap();
124        handle_connection(stream, tx.clone());
125    }
126}
127
128fn handle_connection(mut stream: TcpStream, tx: Sender<String>) {
129    let buf_reader = BufReader::new(&stream);
130    let request_line = &buf_reader.lines().next().unwrap().unwrap();
131
132    debug!("Request: {:#?}", request_line);
133
134    let code_regex = Regex::new(r"code=([^&]*)").unwrap();
135
136    let (status_line, contents) = match code_regex.captures(request_line) {
137        Some(group) => {
138            let code = group.get(1).unwrap().as_str();
139            debug!("Code: {:#?}", code);
140            debug!("Sending code to channel...");
141            debug!("Channel: {:#?}", &tx);
142            match tx.send(code.to_string()) {
143                Ok(_) => {
144                    debug!("Code sent to channel.");
145                }
146                Err(e) => {
147                    debug!("Failed to send code to channel: {e}");
148                }
149            }
150            (
151                "HTTP/1.1 200 OK",
152                "<!DOCTYPE html>
153
154                <head>
155                    <meta charset='utf-8'>
156                    <title>Hello!</title>
157                </head>
158                
159                <body>
160                    <h1>Authenticated!</h1>
161                    <p>Back to the terminal console to finish your registration.</p>
162                </body>",
163            )
164        }
165        None => {
166            debug!("Code not found.");
167            (
168                "HTTP/1.1 404 NOT FOUND",
169                "<!DOCTYPE html>
170                <html lang='en'>
171                
172                <head>
173                    <meta charset='utf-8'>
174                    <title>404 Not found</title>
175                </head>
176                
177                <body>
178                    <h1>Oops!</h1>
179                    <p>Sorry, I don't know what you're asking for.</p>
180                </body>
181                
182                </html>",
183            )
184        }
185    };
186
187    debug!("Contents: {:#?}", &contents);
188    let response = format!("{status_line}\r\n\r\n{contents}");
189    stream.write_all(response.as_bytes()).unwrap();
190    stream.flush().unwrap();
191}
192
193// Get access token
194pub async fn process_connect_github(code: String) -> Result<SmbAuthorization> {
195    let response = Client::new()
196        .post(build_authorize_smb_url())
197        .body(format!("gh_code={}", code))
198        .header("Accept", "application/json")
199        .header("Content-Type", "application/x-www-form-urlencoded")
200        .send()
201        .await?;
202    let mut spinner = Spinner::new(
203        spinners::Spinners::BouncingBall,
204        style("🚀 Authorizing your account...")
205            .green()
206            .bold()
207            .to_string(),
208    );
209    // println!("Response: {:#?}", &response);
210    match response.status() {
211        StatusCode::OK => {
212            // Account authorized and token received
213            spinner.stop_and_persist("✅", "You are logged in with your GitHub account!".into());
214            save_token(&response).await?;
215            let result = response.json().await?;
216            // println!("Result: {:#?}", &result);
217            Ok(result)
218        }
219        StatusCode::NOT_FOUND => {
220            // Account not found and we show signup option
221            spinner.stop_and_persist("🥲", "Account not found. Please signup!".into());
222            let result = response.json().await?;
223            // println!("Result: {:#?}", &result);
224            Ok(result)
225        }
226        StatusCode::UNPROCESSABLE_ENTITY => {
227            // Account found but email not verified
228            spinner.stop_and_persist("🥹", "Unverified email!".into());
229            let result = response.json().await?;
230            // println!("Result: {:#?}", &result);
231            Ok(result)
232        }
233        _ => {
234            // Other errors
235            let error = anyhow!("Error while authorizing with GitHub.");
236            Err(error)
237        }
238    }
239}
240
241fn build_authorize_smb_url() -> String {
242    let mut url_builder = smb_base_url_builder();
243    url_builder.add_route("v1/authorize");
244    url_builder.build()
245}
246
247fn build_github_oauth_url() -> String {
248    let mut url_builder = github_base_url_builder();
249    url_builder
250        .add_route("login/oauth/authorize")
251        .add_param("scope", "user")
252        .add_param("state", "smbpndk");
253    url_builder.build()
254}
255
256fn github_base_url_builder() -> URLBuilder {
257    let redirect_url = format!("{}:{}", GH_OAUTH_REDIRECT_HOST, GH_OAUTH_REDIRECT_PORT);
258
259    let mut url_builder = URLBuilder::new();
260    url_builder
261        .set_protocol("https")
262        .set_host("github.com")
263        .add_param("client_id", GH_OAUTH_CLIENT_ID)
264        .add_param("redirect_uri", &redirect_url);
265    url_builder
266}
267
268pub async fn save_token(response: &Response) -> Result<()> {
269    let headers = response.headers();
270    // println!("Headers: {:#?}", &headers);
271    match headers.get("Authorization") {
272        Some(token) => {
273            debug!("{}", token.to_str()?);
274            match home::home_dir() {
275                Some(path) => {
276                    debug!("{}", path.to_str().unwrap());
277                    create_dir_all(path.join(".smb"))?;
278                    let mut file = OpenOptions::new()
279                        .create(true)
280                        .write(true)
281                        .open([path.to_str().unwrap(), "/.smb/token"].join(""))?;
282                    file.write_all(token.to_str()?.as_bytes())?;
283                    Ok(())
284                }
285                None => Err(anyhow!("Failed to get home directory.")),
286            }
287        }
288        None => Err(anyhow!("Failed to get token. Probably a backend issue.")),
289    }
290}