Skip to main content

client_core/
authenticated_client.rs

1use crate::{ClientRegisterRequest, database::Database};
2use anyhow::Result;
3use reqwest::{Client, Method, RequestBuilder, Response};
4use serde::Serialize;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tracing::{error, info, warn};
8
9/// 认证客户端包装器
10/// 自动处理client_id的设置和认证失败时的重新注册
11#[derive(Debug, Clone)]
12pub struct AuthenticatedClient {
13    client: Client,
14    database: Arc<Database>,
15    server_base_url: String,
16    client_id: Arc<RwLock<Option<String>>>,
17}
18
19impl AuthenticatedClient {
20    /// 创建新的认证客户端
21    pub async fn new(database: Arc<Database>, server_base_url: String) -> Result<Self> {
22        let client = Client::new();
23
24        // 从数据库获取当前的client_id
25        let client_id = database.get_client_id().await?;
26
27        Ok(Self {
28            client,
29            database,
30            server_base_url,
31            client_id: Arc::new(RwLock::new(client_id)),
32        })
33    }
34
35    /// 检查URL是否是我们的服务器
36    fn is_our_server(&self, url: &str) -> bool {
37        url.starts_with(&self.server_base_url)
38    }
39
40    /// 检查是否是注册接口(不需要认证)
41    fn is_register_endpoint(&self, url: &str) -> bool {
42        url.contains("/clients/register")
43    }
44
45    /// 获取当前的client_id
46    async fn get_client_id(&self) -> Option<String> {
47        self.client_id.read().await.clone()
48    }
49
50    /// 更新client_id
51    async fn set_client_id(&self, new_client_id: String) -> Result<()> {
52        // 更新内存中的值
53        *self.client_id.write().await = Some(new_client_id.clone());
54
55        // 保存到数据库
56        self.database.update_client_id(&new_client_id).await?;
57
58        Ok(())
59    }
60
61    /// 自动注册客户端
62    async fn auto_register(&self) -> Result<String> {
63        info!("Attempting to auto-register client...");
64
65        let request = ClientRegisterRequest {
66            os: std::env::consts::OS.to_string(),
67            arch: std::env::consts::ARCH.to_string(),
68        };
69
70        // 使用常量定义的注册端点
71        let register_url = format!(
72            "{}{}",
73            self.server_base_url,
74            crate::constants::api::endpoints::CLIENT_REGISTER
75        );
76        let response = self
77            .client
78            .post(&register_url)
79            .json(&request)
80            .send()
81            .await?;
82
83        if response.status().is_success() {
84            let register_response: serde_json::Value = response.json().await?;
85            if let Some(client_id) = register_response.get("client_id").and_then(|v| v.as_str()) {
86                let client_id = client_id.to_string();
87                info!("Auto-registration successful, client ID: {}", client_id);
88
89                // 保存新的client_id
90                self.set_client_id(client_id.clone()).await?;
91
92                Ok(client_id)
93            } else {
94                Err(anyhow::anyhow!("Invalid registration response format"))
95            }
96        } else {
97            let status = response.status();
98            let text = response.text().await.unwrap_or_default();
99            error!("Client registration failed: {} - {}", status, text);
100            Err(anyhow::anyhow!("Registration failed: {status} - {text}"))
101        }
102    }
103
104    /// 为请求添加认证头
105    async fn add_auth_header(
106        &self,
107        mut request_builder: RequestBuilder,
108        url: &str,
109    ) -> RequestBuilder {
110        // 只对我们的服务器且非注册接口添加认证头
111        if self.is_our_server(url)
112            && !self.is_register_endpoint(url)
113            && let Some(client_id) = self.get_client_id().await
114        {
115            request_builder = request_builder.header("X-Client-ID", client_id);
116        }
117        request_builder
118    }
119
120    /// 执行请求,自动处理认证
121    async fn execute_request(&self, method: Method, url: &str) -> Result<RequestBuilder> {
122        let request_builder = self.client.request(method, url);
123        Ok(self.add_auth_header(request_builder, url).await)
124    }
125
126    /// 执行带JSON body的请求
127    async fn execute_request_with_json<T: Serialize>(
128        &self,
129        method: Method,
130        url: &str,
131        json: &T,
132    ) -> Result<RequestBuilder> {
133        let request_builder = self.client.request(method, url).json(json);
134        Ok(self.add_auth_header(request_builder, url).await)
135    }
136
137    /// 发送请求并处理认证失败
138    async fn send_with_retry(
139        &self,
140        request_builder: RequestBuilder,
141        original_url: &str,
142    ) -> Result<Response> {
143        let response = request_builder.send().await?;
144
145        // 检查是否是认证失败
146        if response.status() == reqwest::StatusCode::UNAUTHORIZED
147            && self.is_our_server(original_url)
148            && !self.is_register_endpoint(original_url)
149        {
150            warn!("API request authentication failed (401), attempting to auto-re-register...");
151
152            // 尝试自动注册
153            match self.auto_register().await {
154                Ok(new_client_id) => {
155                    info!(
156                        "Auto re-registration successful, client ID: {}, retrying request...",
157                        new_client_id
158                    );
159
160                    // 重新从头构建请求,使用新的client_id
161                    // 我们需要重新创建请求,因为原来的RequestBuilder已经被消费
162                    let retry_request_builder = self
163                        .client
164                        .get(original_url)
165                        .header("X-Client-ID", new_client_id);
166
167                    let retry_response = retry_request_builder.send().await?;
168                    Ok(retry_response)
169                }
170                Err(e) => {
171                    error!("Auto re-registration failed: {}", e);
172                    Err(anyhow::anyhow!(
173                        "Authentication failed and unable to re-register: {e}"
174                    ))
175                }
176            }
177        } else {
178            Ok(response)
179        }
180    }
181
182    /// GET请求
183    pub async fn get(&self, url: &str) -> Result<RequestBuilder> {
184        self.execute_request(Method::GET, url).await
185    }
186
187    /// POST请求
188    pub async fn post(&self, url: &str) -> Result<RequestBuilder> {
189        self.execute_request(Method::POST, url).await
190    }
191
192    /// PUT请求
193    pub async fn put(&self, url: &str) -> Result<RequestBuilder> {
194        self.execute_request(Method::PUT, url).await
195    }
196
197    /// DELETE请求
198    pub async fn delete(&self, url: &str) -> Result<RequestBuilder> {
199        self.execute_request(Method::DELETE, url).await
200    }
201
202    /// POST请求(带JSON)
203    pub async fn post_json<T: Serialize>(&self, url: &str, json: &T) -> Result<Response> {
204        let request_builder = self
205            .execute_request_with_json(Method::POST, url, json)
206            .await?;
207        self.send_with_retry(request_builder, url).await
208    }
209
210    /// PUT请求(带JSON)
211    pub async fn put_json<T: Serialize>(&self, url: &str, json: &T) -> Result<Response> {
212        let request_builder = self
213            .execute_request_with_json(Method::PUT, url, json)
214            .await?;
215        self.send_with_retry(request_builder, url).await
216    }
217
218    /// 发送请求(通用方法)
219    pub async fn send(&self, request_builder: RequestBuilder, url: &str) -> Result<Response> {
220        self.send_with_retry(request_builder, url).await
221    }
222
223    /// 获取原始的reqwest客户端(用于特殊情况)
224    pub fn inner(&self) -> &Client {
225        &self.client
226    }
227
228    /// 获取当前的client_id(只读)
229    pub async fn current_client_id(&self) -> Option<String> {
230        self.get_client_id().await
231    }
232}