quic_rpc_utils/
lib.rs

1//! # RPC 工具和结构体
2//!
3//! 这个模块包含了一些用于处理RPC(远程过程调用)的实用工具和结构体。
4//! 它使用了多个库,包括`anyhow`、`flume`、`futures-lite`、`futures-util`、`quic_rpc`、`tokio`和`tracing`。
5//!
6//! ## 主要组件和功能
7//!
8//! - **导入和重导出**:使用`pub use`语句重导出了一些库中的类型和函数,以便在其他模块中直接使用。
9//! - **`GetServiceHandler` trait**:定义了一个trait,用于获取特定服务的处理程序。
10//! - **`ServiceHandler` trait**:定义了一个trait,用于服务的处理程序。
11//! - **`run_server` 函数**:用于运行一个RPC服务器。
12//! - **`ClientStreamingResponse` 结构体**:用于处理客户端流式RPC响应。
13//! - **`ServerStreamingResponse` 结构体**:用于处理服务器流式RPC响应。
14//!
15//! ## 注意事项
16//!
17//! - **异步处理**:代码中大量使用了异步处理,包括`async/await`关键字和`Future` trait。
18//! - **错误处理**:代码中使用了`Result`和`Error`类型进行错误处理。
19//! - **并发**:代码中使用了`Arc`和`LazyLock`来处理并发访问和延迟初始化。
20//! - **类型安全**:通过trait和泛型,代码确保了类型安全。
21//!
22//! ## 用途
23//!
24//! 这个模块给微服务框架导出了一些工具函数和结构体,可以处理不同类型的RPC请求,包括客户端流式和服务器流式请求。
25//! 它提供了基本的错误处理和并发支持,使得开发者可以更容易地构建高性能的RPC服务。
26
27mod error;
28mod transport;
29
30pub use error::{QuicRpcWrapError, Result};
31pub use flume::bounded as flume_bounded;
32use futures_lite::future::Boxed;
33pub use futures_lite::stream::{Stream, StreamExt};
34use futures_util::SinkExt;
35#[cfg(feature = "quinn")]
36pub use iroh_quinn::{
37    ClientConfig, Endpoint as QuinnEndpoint, ServerConfig,
38    crypto::rustls::{QuicClientConfig, QuicServerConfig},
39    rustls::{
40        RootCertStore,
41        pki_types::{CertificateDer, PrivatePkcs8KeyDer},
42        version::TLS13,
43    },
44};
45#[cfg(feature = "flume")]
46pub use quic_rpc::transport::flume::channel as flume_channel;
47#[cfg(feature = "hyper")]
48pub use quic_rpc::transport::hyper::{HyperConnector, HyperListener};
49#[cfg(feature = "quinn")]
50pub use quic_rpc::transport::quinn::{QuinnConnector, QuinnListener};
51pub use quic_rpc::{
52    Connector, Listener, RpcClient, RpcMessage, RpcServer, Service,
53    client::{BoxStreamSync, BoxedConnector, UpdateSink},
54    message::{
55        BidiStreaming, BidiStreamingMsg, ClientStreaming, ClientStreamingMsg, Msg, RpcMsg,
56        ServerStreaming, ServerStreamingMsg,
57    },
58    server::{BoxedChannelTypes, BoxedListener, ChannelTypes, RpcChannel},
59};
60#[cfg(feature = "quinn")]
61use std::{
62    fs::File,
63    io::{Read, Write},
64    path::Path,
65};
66use std::{
67    future::Future,
68    marker::PhantomData,
69    mem::replace,
70    pin::Pin,
71    sync::{Arc, LazyLock},
72};
73use tokio::runtime::Builder;
74pub use tokio::{pin, runtime::Runtime, sync::oneshot::channel as oneshot_channel};
75use tracing::{debug, error, warn};
76#[cfg(feature = "pipe")]
77pub use transport::pipe::{PipeConnector, PipeListener};
78#[cfg(feature = "iroh")]
79pub use {
80    iroh::{Endpoint as IrohEndpoint, NodeAddr, SecretKey},
81    quic_rpc::transport::iroh::{IrohConnector, IrohListener},
82};
83
84/// 获取特定服务的处理程序
85pub trait GetServiceHandler<S: Service> {
86    fn get_handler(self: Arc<Self>) -> Arc<S>;
87}
88
89/// 服务的处理程序
90pub trait ServiceHandler<S: Service, C: ChannelTypes<S> = BoxedChannelTypes<S>> {
91    /// 用于服务端处理请求和响应。
92    ///
93    /// # Arguments
94    ///
95    /// * `req`: 请求参数。
96    /// * `chan`: 连接通道。
97    /// * `rt`: 异步运行时。
98    ///
99    /// returns: impl Future<Output=Result<()>>+Send+Sized 是否处理成功。
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// None::<()>;
105    /// ```
106    fn handle_rpc_request(
107        self: Arc<Self>,
108        req: S::Req,
109        chan: RpcChannel<S, C>,
110        rt: &'static Runtime,
111    ) -> impl Future<Output = Result<()>> + Send;
112}
113
114pub static TIME: std::sync::LazyLock<std::time::Instant> =
115    std::sync::LazyLock::new(std::time::Instant::now);
116
117pub async fn run_server<S, L>(server: RpcServer<S, L>)
118where
119    L: Listener<S>,
120    S: Service + Default + ServiceHandler<S>,
121{
122    let service = Arc::new(S::default());
123    debug!("{:?}", service);
124    static RT: LazyLock<Runtime> =
125        LazyLock::new(|| Builder::new_multi_thread().enable_all().build().unwrap());
126    loop {
127        let Ok(accepting) = server.accept().await else {
128            continue;
129        };
130
131        match accepting.read_first().await {
132            Err(err) => warn!(?err, "server accept failed"),
133            Ok((req, chan)) => {
134                let handler = service.clone();
135                RT.spawn(async move {
136                    if let Err(err) =
137                        S::handle_rpc_request(handler, req, chan.map().boxed(), &*RT).await
138                    {
139                        warn!(?err, "internal rpc error");
140                    }
141                });
142            }
143        }
144    }
145}
146
147pub struct ClientStreamingResponse<T, S, C, R>(
148    Option<UpdateSink<C, T>>,
149    Boxed<Result<R>>,
150    PhantomData<S>,
151)
152where
153    S: Service,
154    C: Connector<S>,
155    T: Into<C::Out>;
156
157impl<T, S, C, R> ClientStreamingResponse<T, S, C, R>
158where
159    S: Service,
160    C: Connector<S>,
161    T: Into<C::Out>,
162{
163    pub fn new(
164        sink: UpdateSink<C, T>,
165        result: impl Future<Output = Result<R>> + Send + 'static,
166    ) -> Self {
167        Self(sink.into(), Box::pin(result) as _, Default::default())
168    }
169
170    pub async fn put(&mut self, item: T) -> &mut Self {
171        let Some(sink) = self.0.as_mut() else {
172            return self;
173        };
174        if let Err(e) = sink.send(item).await {
175            error!("Send data error. ({})", e);
176        }
177        self
178    }
179
180    pub async fn result(&mut self) -> Result<R> {
181        let Some(mut sink) = replace(&mut self.0, None) else {
182            return Err(QuicRpcWrapError::BadSink);
183        };
184        sink.close()
185            .await
186            .map_err(|e| QuicRpcWrapError::Send(e.to_string()))?;
187        drop(sink);
188        replace(
189            &mut self.1,
190            Box::pin(async { Err(QuicRpcWrapError::ResultAlreadyTakenAway) }) as _,
191        )
192        .await
193    }
194}
195
196pub struct ServerStreamingResponse<R>(Pin<Box<dyn Stream<Item = Result<R>> + Send>>);
197
198impl<R> ServerStreamingResponse<R> {
199    pub fn new(stream: impl Stream<Item = Result<R>> + Send + 'static) -> Self {
200        Self(Box::pin(stream) as _)
201    }
202
203    pub async fn next(&mut self) -> Option<Result<R>> {
204        self.0.next().await
205    }
206}
207
208unsafe impl<R> Send for ServerStreamingResponse<R> {}
209
210//noinspection SpellCheckingInspection
211/// 生成服务器证书和私钥。
212///
213/// 这个函数使用`rcgen`库生成一个自签名证书和对应的私钥。
214///
215/// # 参数
216///
217/// * `subject_alt_names` - 一个字符串切片,包含主题备用名称(Subject Alternative Names)。
218///
219/// # 返回值
220///
221/// 返回一个包含证书和私钥字节数组的元组。如果生成证书失败,将返回一个错误。
222///
223/// # 示例
224///
225/// ```rust
226/// use std::error::Error;
227///
228/// #[cfg(feature = "quinn")]
229/// fn main() -> Result<(), Box<dyn Error>> {
230///     let subject_alt_names = vec!["example.com", "localhost"];
231///     let (cert, key) = quic_rpc_utils::gen_server_cert(&subject_alt_names)?;
232///     println!("Certificate: {:?}", cert);
233///     println!("Private Key: {:?}", key);
234///     Ok(())
235/// }
236/// ```
237///
238/// # 注意
239///
240/// 1. 生成的证书仅用于开发和测试,不应在生产环境中使用。
241/// 2. 确保在编译时启用了`quinn`特性。
242///
243/// # 错误处理
244///
245/// 如果生成证书失败,将返回一个错误。
246#[cfg(feature = "quinn")]
247pub fn gen_server_cert(subject_alt_names: &[&str]) -> Result<(Vec<u8>, Vec<u8>)> {
248    let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed(
249        subject_alt_names
250            .iter()
251            .map(|i| i.to_string())
252            .collect::<Vec<_>>(),
253    )?;
254    let cert_der = cert.der();
255    let private_key = signing_key.serialize_der();
256    Ok((cert_der.to_vec(), private_key))
257}
258
259/// 将证书文件和私钥文件保存到指定的文件路径中。
260///
261/// # 参数
262///
263/// - `cert_der: &[u8]`:证书文件的DER编码格式数据。
264/// - `private_key: &[u8]`:私钥文件的数据。
265/// - `cert_der_file: impl AsRef<Path>`:证书文件保存路径。
266/// - `private_key_file: impl AsRef<Path>`:私钥文件保存路径。
267///
268/// # 返回值
269///
270/// - `Result<()>`:如果保存成功,返回`Ok(())`;如果保存失败,返回`Err`。
271///
272/// # 示例
273///
274/// ```rust
275/// use std::path::Path;
276/// use quic_rpc_utils::save_cert_file;
277///
278/// let cert_der = vec![0x30, 0x82, 0x01, 0x22, /* ... */];
279/// let private_key = vec![0x30, 0x82, 0x01, 0x22, /* ... */];
280/// let cert_der_file = Path::new("path/to/cert.crt");
281/// let private_key_file = Path::new("path/to/private_key.key");
282///
283/// if let Err(e) = save_cert_file(&cert_der, &private_key, cert_der_file, private_key_file) {
284///     eprintln!("保存文件失败: {}", e);
285/// }
286/// ```
287#[cfg(feature = "quinn")]
288pub fn save_cert_file(
289    cert_der: &[u8],
290    private_key: &[u8],
291    cert_der_file: impl AsRef<Path>,
292    private_key_file: impl AsRef<Path>,
293) -> Result<()> {
294    File::create(cert_der_file)?.write_all(cert_der)?;
295    File::create(private_key_file)?.write_all(private_key)?;
296    Ok(())
297}
298
299/// 从指定的文件路径中读取证书文件和私钥文件的数据。
300///
301/// # 参数
302///
303/// - `cert_der_file: impl AsRef<Path>`:证书文件路径。
304/// - `private_key_file: impl AsRef<Path>`:私钥文件路径。
305///
306/// # 返回值
307///
308/// - `Result<(Vec<u8>, Vec<u8>)>`:如果读取成功,返回包含证书文件和私钥文件数据的元组;如果读取失败,返回`Err`。
309///
310/// # 示例
311///
312/// ```rust
313/// use std::path::Path;
314/// use quic_rpc_utils::read_cert_file;
315///
316/// let cert_der_file = Path::new("path/to/cert.der");
317/// let private_key_file = Path::new("path/to/private_key.pem");
318///
319/// match read_cert_file(cert_der_file, private_key_file) {
320///     Ok((cert_der, private_key)) => {
321///         println!("读取的证书文件数据: {:?}", cert_der);
322///         println!("读取的私钥文件数据: {:?}", private_key);
323///     }
324///     Err(e) => {
325///         eprintln!("读取文件失败: {}", e);
326///     }
327/// }
328/// ```
329#[cfg(feature = "quinn")]
330pub fn read_cert_file(
331    cert_der_file: impl AsRef<Path>,
332    private_key_file: impl AsRef<Path>,
333) -> Result<(Vec<u8>, Vec<u8>)> {
334    let (mut cert_der, mut key) = Default::default();
335    File::open(cert_der_file)?.read_to_end(&mut cert_der)?;
336    File::open(private_key_file)?.read_to_end(&mut key)?;
337    Ok((cert_der, key))
338}
339
340/// 返回默认服务器配置
341#[cfg(feature = "quinn")]
342#[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527
343pub fn configure_server(
344    max_concurrent_uni_streams: u8,
345    cert_der: Vec<u8>,
346    private_key: Vec<u8>,
347) -> Result<ServerConfig> {
348    let private_key = PrivatePkcs8KeyDer::from(private_key);
349    let cert_chain = vec![CertificateDer::from(cert_der)];
350
351    let crypto_server_config = iroh_quinn::rustls::ServerConfig::builder_with_provider(Arc::new(
352        iroh_quinn::rustls::crypto::ring::default_provider(),
353    ))
354    .with_protocol_versions(&[&TLS13])
355    .expect("valid versions")
356    .with_no_client_auth()
357    .with_single_cert(cert_chain, private_key.into())?;
358    let quic_server_config = QuicServerConfig::try_from(crypto_server_config)?;
359    let mut server_config = ServerConfig::with_crypto(Arc::new(quic_server_config));
360
361    Arc::get_mut(&mut server_config.transport)
362        .unwrap()
363        .max_concurrent_uni_streams(max_concurrent_uni_streams.into());
364
365    Ok(server_config)
366}
367
368/// 构建默认的 quinn 客户端配置并信任给定的证书。
369///
370/// ## Args
371///
372/// - `server_certs`: DER 格式的受信任证书列表。
373#[cfg(feature = "quinn")]
374pub fn configure_client(server_certs: &[&[u8]]) -> Result<ClientConfig> {
375    let mut certs = RootCertStore::empty();
376    for cert in server_certs {
377        let cert = CertificateDer::from(cert.to_vec());
378        certs.add(cert)?;
379    }
380
381    let crypto_client_config = iroh_quinn::rustls::ClientConfig::builder_with_provider(Arc::new(
382        iroh_quinn::rustls::crypto::ring::default_provider(),
383    ))
384    .with_protocol_versions(&[&TLS13])
385    .expect("valid versions")
386    .with_root_certificates(certs)
387    .with_no_client_auth();
388    let quic_client_config = QuicClientConfig::try_from(crypto_client_config)?;
389
390    Ok(ClientConfig::new(Arc::new(quic_client_config)))
391}