llm_cost_ops_api/api/
server.rs1use axum::{middleware, Router};
4use serde::{Deserialize, Serialize};
5use std::net::SocketAddr;
6use std::time::Duration;
7use tower::ServiceBuilder;
8use tower_http::{
9 cors::CorsLayer,
10 timeout::TimeoutLayer,
11 trace::TraceLayer,
12};
13
14use super::{middleware as api_middleware, routes};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ApiServerConfig {
19 pub host: String,
21
22 pub port: u16,
24
25 pub request_timeout_secs: u64,
27
28 pub enable_cors: bool,
30
31 pub enable_logging: bool,
33}
34
35impl Default for ApiServerConfig {
36 fn default() -> Self {
37 Self {
38 host: "0.0.0.0".to_string(),
39 port: 8080,
40 request_timeout_secs: 30,
41 enable_cors: true,
42 enable_logging: true,
43 }
44 }
45}
46
47impl ApiServerConfig {
48 pub fn socket_addr(&self) -> Result<SocketAddr, String> {
50 format!("{}:{}", self.host, self.port)
51 .parse()
52 .map_err(|e| format!("Invalid socket address: {}", e))
53 }
54}
55
56pub struct ApiServer {
58 config: ApiServerConfig,
59}
60
61impl ApiServer {
62 pub fn new(config: ApiServerConfig) -> Self {
64 Self { config }
65 }
66
67 pub fn build_router(&self) -> Router {
69 create_api_router(&self.config)
70 }
71
72 pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
74 let addr = self.config.socket_addr()?;
75 let app = self.build_router();
76
77 tracing::info!("Starting API server on {}", addr);
78
79 let listener = tokio::net::TcpListener::bind(addr).await?;
80 axum::serve(listener, app).await?;
81
82 Ok(())
83 }
84}
85
86pub fn create_api_router(config: &ApiServerConfig) -> Router {
88 let mut router = routes::create_routes();
89
90 let middleware_stack = ServiceBuilder::new()
92 .layer(TimeoutLayer::new(Duration::from_secs(
93 config.request_timeout_secs,
94 )))
95 .layer(middleware::from_fn(api_middleware::request_id_middleware));
96
97 router = router.layer(middleware_stack);
98
99 if config.enable_cors {
101 router = router.layer(CorsLayer::permissive());
102 }
103
104 if config.enable_logging {
106 router = router.layer(TraceLayer::new_for_http());
107 }
108
109 router
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default value, including timeout and logging, is pinned.
    #[test]
    fn test_default_config() {
        let config = ApiServerConfig::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 8080);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(config.enable_cors);
        assert!(config.enable_logging);
    }

    #[test]
    fn test_socket_addr() {
        let config = ApiServerConfig {
            host: "127.0.0.1".to_string(),
            port: 3000,
            ..Default::default()
        };

        let addr = config.socket_addr().unwrap();
        assert_eq!(addr.to_string(), "127.0.0.1:3000");
    }

    /// A host that is not an IP literal must surface as an error, not a panic.
    #[test]
    fn test_socket_addr_rejects_invalid_host() {
        let config = ApiServerConfig {
            host: "not a host".to_string(),
            ..Default::default()
        };

        let err = config.socket_addr().unwrap_err();
        assert!(err.starts_with("Invalid socket address"));
    }

    #[test]
    fn test_server_creation() {
        let config = ApiServerConfig::default();
        let server = ApiServer::new(config);
        let _router = server.build_router();
    }

    /// The router must also build with the optional CORS and tracing layers off.
    #[test]
    fn test_router_without_optional_layers() {
        let config = ApiServerConfig {
            enable_cors: false,
            enable_logging: false,
            ..Default::default()
        };
        let _router = create_api_router(&config);
    }
}