1use crate::compatibility::CompatibilityLevel;
14use crate::error::SchemaError;
15use crate::registry::SchemaRegistry;
16use crate::types::{
17 SchemaContext, SchemaId, SchemaType, SchemaVersion, Subject, ValidationLevel, ValidationRule,
18 ValidationRuleType, VersionState,
19};
20use axum::{
21 extract::{Path, Query, State},
22 http::StatusCode,
23 routing::{delete, get, post, put},
24 Json, Router,
25};
26use parking_lot::RwLock;
27use serde::{Deserialize, Serialize};
28use std::net::SocketAddr;
29use std::sync::Arc;
30use tower_http::cors::{Any, CorsLayer};
31use tower_http::trace::TraceLayer;
32use tracing::info;
33
34#[cfg(feature = "auth")]
35use crate::auth::{auth_middleware, AuthConfig, ServerAuthState};
36#[cfg(feature = "auth")]
37use axum::middleware;
38#[cfg(feature = "auth")]
39use rivven_core::AuthManager;
40
/// Settings for the HTTP front-end of the schema registry.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host the server is configured for (`"0.0.0.0"` by default).
    ///
    /// `SchemaServer::run` takes its socket address explicitly and does not read this field.
    pub host: String,
    /// Port the server is configured for (`8081` by default).
    pub port: u16,
    /// Optional authentication settings; `None` by default.
    #[cfg(feature = "auth")]
    pub auth: Option<AuthConfig>,
}

impl Default for ServerConfig {
    fn default() -> Self {
        ServerConfig {
            host: String::from("0.0.0.0"),
            port: 8081,
            #[cfg(feature = "auth")]
            auth: None,
        }
    }
}

impl ServerConfig {
    /// Returns this configuration with `auth_config` installed as the authentication settings.
    #[cfg(feature = "auth")]
    pub fn with_auth(self, auth_config: AuthConfig) -> Self {
        ServerConfig {
            auth: Some(auth_config),
            ..self
        }
    }
}
70
/// Operating mode of the registry, read and changed through `GET`/`PUT /mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RegistryMode {
    /// Normal operation; this is the default mode.
    #[default]
    ReadWrite,
    /// Mutating handlers that check the mode reject requests with `403` while this is set.
    ReadOnly,
    /// Import mode. No handler in this module treats it differently from `ReadWrite`.
    Import,
}

impl std::fmt::Display for RegistryMode {
    /// Renders the upper-case wire name used in `/mode` responses.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::ReadWrite => "READWRITE",
            Self::ReadOnly => "READONLY",
            Self::Import => "IMPORT",
        };
        f.write_str(label)
    }
}
89
/// State shared by every request handler (extracted as `State<Arc<ServerState>>`).
pub struct ServerState {
    /// Registry that the handlers delegate schema operations to.
    pub registry: SchemaRegistry,
    /// Current registry mode. It is held only in this process, and every `SchemaServer`
    /// constructor starts it at `RegistryMode::default()` (`ReadWrite`).
    pub mode: RwLock<RegistryMode>,
    /// Global compatibility level served by `GET /config` and replaced by `PUT /config`.
    /// It is seeded from `registry.get_default_compatibility()` at construction.
    // NOTE(review): `PUT /config` only writes this copy; per-subject lookups go through
    // `registry.get_subject_compatibility`. Confirm that the two are meant to diverge.
    pub default_compatibility: RwLock<CompatibilityLevel>,
}
96
/// HTTP server exposing a [`SchemaRegistry`] over REST.
///
/// The route layout (e.g. `/subjects/:subject/versions`) appears to be modelled on the
/// Confluent Schema Registry API.
pub struct SchemaServer {
    /// Handler state, shared with the router.
    state: Arc<ServerState>,
    // Only read by `router()` when the `auth` feature is enabled. Without that feature the
    // field would otherwise trigger a dead-code warning.
    #[allow(dead_code)]
    config: ServerConfig,
    // The two `auth_manager` declarations use mutually exclusive cfgs, so at most one copy of
    // the field exists in any build.
    /// Authentication manager; when `Some`, the router installs the auth middleware.
    #[cfg(all(feature = "auth", not(feature = "cedar")))]
    auth_manager: Option<Arc<AuthManager>>,
    /// Authentication manager; when `Some`, the router installs the auth middleware.
    #[cfg(feature = "cedar")]
    auth_manager: Option<Arc<AuthManager>>,
    /// Optional Cedar authorizer that is passed to the auth middleware state.
    #[cfg(feature = "cedar")]
    cedar_authorizer: Option<Arc<rivven_core::CedarAuthorizer>>,
}
109
110impl SchemaServer {
111 pub fn new(registry: SchemaRegistry, config: ServerConfig) -> Self {
113 let default_compat = registry.get_default_compatibility();
114 Self {
115 state: Arc::new(ServerState {
116 registry,
117 mode: RwLock::new(RegistryMode::default()),
118 default_compatibility: RwLock::new(default_compat),
119 }),
120 config,
121 #[cfg(feature = "auth")]
122 auth_manager: None,
123 #[cfg(feature = "cedar")]
124 cedar_authorizer: None,
125 }
126 }
127
128 #[cfg(all(feature = "auth", not(feature = "cedar")))]
130 pub fn with_auth(
131 registry: SchemaRegistry,
132 config: ServerConfig,
133 auth_manager: Arc<AuthManager>,
134 ) -> Self {
135 let default_compat = registry.get_default_compatibility();
136 Self {
137 state: Arc::new(ServerState {
138 registry,
139 mode: RwLock::new(RegistryMode::default()),
140 default_compatibility: RwLock::new(default_compat),
141 }),
142 config,
143 auth_manager: Some(auth_manager),
144 }
145 }
146
147 #[cfg(feature = "cedar")]
149 pub fn with_auth(
150 registry: SchemaRegistry,
151 config: ServerConfig,
152 auth_manager: Arc<AuthManager>,
153 ) -> Self {
154 let default_compat = registry.get_default_compatibility();
155 Self {
156 state: Arc::new(ServerState {
157 registry,
158 mode: RwLock::new(RegistryMode::default()),
159 default_compatibility: RwLock::new(default_compat),
160 }),
161 config,
162 auth_manager: Some(auth_manager),
163 cedar_authorizer: None,
164 }
165 }
166
167 #[cfg(feature = "cedar")]
169 pub fn with_cedar(
170 registry: SchemaRegistry,
171 config: ServerConfig,
172 auth_manager: Arc<AuthManager>,
173 cedar_authorizer: Arc<rivven_core::CedarAuthorizer>,
174 ) -> Self {
175 let default_compat = registry.get_default_compatibility();
176 Self {
177 state: Arc::new(ServerState {
178 registry,
179 mode: RwLock::new(RegistryMode::default()),
180 default_compatibility: RwLock::new(default_compat),
181 }),
182 config,
183 auth_manager: Some(auth_manager),
184 cedar_authorizer: Some(cedar_authorizer),
185 }
186 }
187
188 pub fn router(&self) -> Router {
190 let cors = CorsLayer::new()
191 .allow_origin(Any)
192 .allow_methods(Any)
193 .allow_headers(Any);
194
195 let base_router = Router::new()
196 .route("/", get(root_handler))
198 .route("/health", get(health_handler))
199 .route("/health/live", get(liveness_handler))
200 .route("/health/ready", get(readiness_handler))
201 .route("/schemas/ids/:id", get(get_schema_by_id))
203 .route("/schemas/ids/:id/versions", get(get_schema_versions))
204 .route("/subjects", get(list_subjects))
206 .route("/subjects/:subject", delete(delete_subject))
207 .route("/subjects/:subject/versions", get(list_subject_versions))
208 .route("/subjects/:subject/versions", post(register_schema))
209 .route(
210 "/subjects/:subject/versions/:version",
211 get(get_subject_version),
212 )
213 .route(
214 "/subjects/:subject/versions/:version",
215 delete(delete_version),
216 )
217 .route("/subjects/deleted", get(list_deleted_subjects))
219 .route("/subjects/:subject/undelete", post(undelete_subject))
220 .route(
222 "/subjects/:subject/versions/:version/state",
223 get(get_version_state),
224 )
225 .route(
226 "/subjects/:subject/versions/:version/state",
227 put(set_version_state),
228 )
229 .route(
230 "/subjects/:subject/versions/:version/deprecate",
231 post(deprecate_version),
232 )
233 .route(
234 "/subjects/:subject/versions/:version/disable",
235 post(disable_version),
236 )
237 .route(
238 "/subjects/:subject/versions/:version/enable",
239 post(enable_version),
240 )
241 .route(
243 "/subjects/:subject/versions/:version/referencedby",
244 get(get_referenced_by),
245 )
246 .route("/subjects/:subject/validate", post(validate_schema))
248 .route(
250 "/compatibility/subjects/:subject/versions/:version",
251 post(check_compatibility),
252 )
253 .route("/config", get(get_global_config))
255 .route("/config", put(update_global_config))
256 .route("/config/:subject", get(get_subject_config))
257 .route("/config/:subject", put(update_subject_config))
258 .route("/config/validation/rules", get(list_validation_rules))
260 .route("/config/validation/rules", post(add_validation_rule))
261 .route(
262 "/config/validation/rules/:name",
263 delete(delete_validation_rule),
264 )
265 .route("/mode", get(get_mode))
267 .route("/mode", put(update_mode))
268 .route("/contexts", get(list_contexts))
270 .route("/contexts", post(create_context))
271 .route("/contexts/:context", get(get_context))
272 .route("/contexts/:context", delete(delete_context))
273 .route("/contexts/:context/subjects", get(list_context_subjects))
274 .route("/stats", get(get_stats))
276 .with_state(self.state.clone())
277 .layer(cors)
278 .layer(TraceLayer::new_for_http());
279
280 #[cfg(all(feature = "auth", not(feature = "cedar")))]
282 let base_router = if let Some(auth_manager) = &self.auth_manager {
283 let auth_config = self.config.auth.clone().unwrap_or_default();
284 let auth_state = Arc::new(ServerAuthState {
285 auth_manager: auth_manager.clone(),
286 config: auth_config,
287 });
288 base_router.layer(middleware::from_fn_with_state(auth_state, auth_middleware))
289 } else {
290 base_router
291 };
292
293 #[cfg(feature = "cedar")]
295 let base_router = if let Some(auth_manager) = &self.auth_manager {
296 let auth_config = self.config.auth.clone().unwrap_or_default();
297 let auth_state = Arc::new(ServerAuthState {
298 auth_manager: auth_manager.clone(),
299 cedar_authorizer: self.cedar_authorizer.clone(),
300 config: auth_config,
301 });
302 base_router.layer(middleware::from_fn_with_state(auth_state, auth_middleware))
303 } else {
304 base_router
305 };
306
307 base_router
308 }
309
310 pub async fn run(self, addr: SocketAddr) -> anyhow::Result<()> {
312 let router = self.router();
313 let listener = tokio::net::TcpListener::bind(addr).await?;
314 info!("Schema Registry server listening on {}", addr);
315 axum::serve(listener, router).await?;
316 Ok(())
317 }
318}
319
320fn schema_error_response(e: SchemaError) -> (StatusCode, Json<ErrorResponse>) {
325 let status = match e.http_status() {
326 404 => StatusCode::NOT_FOUND,
327 409 => StatusCode::CONFLICT,
328 422 => StatusCode::UNPROCESSABLE_ENTITY,
329 _ => StatusCode::INTERNAL_SERVER_ERROR,
330 };
331 (
332 status,
333 Json(ErrorResponse {
334 error_code: e.error_code(),
335 message: e.to_string(),
336 }),
337 )
338}
339
/// JSON body returned by `GET /`.
#[derive(Serialize)]
struct RootResponse {
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
    version: &'static str,
    /// Source commit; `root_handler` always reports `"unknown"`.
    commit: &'static str,
}
349
/// Body of `POST /subjects/:subject/versions`.
#[derive(Deserialize)]
struct RegisterSchemaRequest {
    /// Schema definition text.
    schema: String,
    /// `schemaType` in JSON. `register_schema` treats a missing or unparseable value as AVRO.
    #[serde(rename = "schemaType", default)]
    schema_type: Option<String>,
    /// Schemas this one references; an empty list when the field is omitted.
    #[serde(default)]
    references: Vec<SchemaReference>,
}
359
/// Wire form of a schema reference. Clients send it when registering a schema, and it is
/// returned by `GET /schemas/ids/:id`.
#[derive(Deserialize, Serialize)]
struct SchemaReference {
    /// Reference name, as given by the client.
    name: String,
    /// Subject of the referenced schema.
    subject: String,
    /// Version of the referenced schema within `subject`.
    version: u32,
}
366
/// Body returned by `POST /subjects/:subject/versions`.
#[derive(Serialize)]
struct RegisterSchemaResponse {
    /// Schema id returned by the registry for the registered schema.
    id: u32,
}
371
/// Body returned by `GET /schemas/ids/:id`.
#[derive(Serialize)]
struct SchemaResponse {
    /// Schema definition text.
    schema: String,
    /// Schema type rendered with `Display`, serialized as `schemaType`.
    #[serde(rename = "schemaType")]
    schema_type: String,
    /// References of the schema; omitted from the JSON when empty.
    // `default` only affects deserialization, which this type does not derive.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    references: Vec<SchemaReference>,
}
380
/// A single subject version. It is returned by `GET /subjects/:subject/versions/:version`
/// and, as a list, by `GET /schemas/ids/:id/versions`.
#[derive(Serialize)]
struct SubjectVersionResponse {
    /// Subject name.
    subject: String,
    /// Version number within the subject.
    version: u32,
    /// Global schema id.
    id: u32,
    /// Schema definition text.
    schema: String,
    /// Schema type rendered with `Display`, serialized as `schemaType`.
    #[serde(rename = "schemaType")]
    schema_type: String,
}
390
/// Body of `POST /compatibility/subjects/:subject/versions/:version`.
#[derive(Deserialize)]
struct CompatibilityRequest {
    /// Candidate schema text to check.
    schema: String,
    /// `schemaType` in JSON. `check_compatibility` treats a missing or unparseable value as AVRO.
    #[serde(rename = "schemaType", default)]
    schema_type: Option<String>,
}
397
/// Result of a compatibility check.
#[derive(Serialize)]
struct CompatibilityResponse {
    /// Whether the candidate schema is compatible.
    is_compatible: bool,
    /// Explanatory messages from the registry; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    messages: Vec<String>,
}
404
/// Body returned by the global and per-subject `/config` endpoints.
#[derive(Serialize)]
struct ConfigResponse {
    /// Compatibility level rendered with `Display`, serialized as `compatibilityLevel`.
    #[serde(rename = "compatibilityLevel")]
    compatibility_level: String,
}
410
/// Body of `PUT /config` and `PUT /config/:subject`.
#[derive(Deserialize)]
struct ConfigRequest {
    /// Requested compatibility level; the handlers parse it into a `CompatibilityLevel`.
    compatibility: String,
}
415
/// Body returned by `GET /mode` and `PUT /mode`.
#[derive(Serialize)]
struct ModeResponse {
    /// Mode name as produced by `RegistryMode`'s `Display` impl.
    mode: String,
}
420
/// Body of `PUT /mode`.
#[derive(Deserialize)]
struct ModeRequest {
    /// Requested mode. `update_mode` upper-cases it before matching.
    mode: String,
}
425
/// Error envelope sent with every failing response: `{"error_code": ..., "message": ...}`.
#[derive(Serialize)]
struct ErrorResponse {
    /// Numeric application error code (e.g. 40301 for the READONLY rejection).
    error_code: u32,
    /// Human-readable error description.
    message: String,
}
431
/// Query string accepted by the subject and version delete endpoints.
#[derive(Deserialize)]
struct QueryParams {
    /// `?permanent=...`; `false` when absent. It is forwarded to the registry's delete call,
    /// presumably to select a hard delete over a soft one (TODO confirm in the registry).
    #[serde(default)]
    permanent: bool,
}
437
438async fn root_handler() -> Json<RootResponse> {
443 Json(RootResponse {
444 version: env!("CARGO_PKG_VERSION"),
445 commit: "unknown",
446 })
447}
448
/// JSON body returned by `GET /health`.
#[derive(Serialize)]
struct HealthResponse {
    /// Always `"healthy"` when produced by `health_handler`.
    status: &'static str,
    /// Crate version (`CARGO_PKG_VERSION`).
    version: &'static str,
}
455
456async fn health_handler() -> Json<HealthResponse> {
457 Json(HealthResponse {
458 status: "healthy",
459 version: env!("CARGO_PKG_VERSION"),
460 })
461}
462
/// `GET /health/live`: liveness probe.
///
/// Always answers `200 OK` without touching the registry. It only shows that the process is
/// up and serving HTTP.
async fn liveness_handler() -> StatusCode {
    StatusCode::OK
}
467
468async fn readiness_handler(
470 State(state): State<Arc<ServerState>>,
471) -> Result<StatusCode, StatusCode> {
472 match state.registry.list_subjects().await {
474 Ok(_) => Ok(StatusCode::OK),
475 Err(_) => Err(StatusCode::SERVICE_UNAVAILABLE),
476 }
477}
478
479async fn get_schema_by_id(
480 State(state): State<Arc<ServerState>>,
481 Path(id): Path<u32>,
482) -> Result<Json<SchemaResponse>, (StatusCode, Json<ErrorResponse>)> {
483 state
484 .registry
485 .get_by_id(SchemaId::new(id))
486 .await
487 .map(|schema| {
488 Json(SchemaResponse {
489 schema: schema.schema,
490 schema_type: schema.schema_type.to_string(),
491 references: schema
492 .references
493 .into_iter()
494 .map(|r| SchemaReference {
495 name: r.name,
496 subject: r.subject,
497 version: r.version,
498 })
499 .collect(),
500 })
501 })
502 .map_err(schema_error_response)
503}
504
505async fn get_schema_versions(
506 State(state): State<Arc<ServerState>>,
507 Path(id): Path<u32>,
508) -> Result<Json<Vec<SubjectVersionResponse>>, (StatusCode, Json<ErrorResponse>)> {
509 let subjects = state
511 .registry
512 .list_subjects()
513 .await
514 .map_err(schema_error_response)?;
515 let mut results = Vec::new();
516
517 for subject in subjects {
518 let versions = state
519 .registry
520 .list_versions(subject.as_str())
521 .await
522 .map_err(schema_error_response)?;
523 for ver in versions {
524 if let Ok(sv) = state
525 .registry
526 .get_by_version(subject.as_str(), SchemaVersion::new(ver))
527 .await
528 {
529 if sv.id.0 == id {
530 results.push(SubjectVersionResponse {
531 subject: sv.subject.0,
532 version: sv.version.0,
533 id: sv.id.0,
534 schema: sv.schema,
535 schema_type: sv.schema_type.to_string(),
536 });
537 }
538 }
539 }
540 }
541
542 Ok(Json(results))
543}
544
545async fn list_subjects(
546 State(state): State<Arc<ServerState>>,
547) -> Result<Json<Vec<String>>, (StatusCode, Json<ErrorResponse>)> {
548 state
549 .registry
550 .list_subjects()
551 .await
552 .map(|subjects| Json(subjects.into_iter().map(|s| s.0).collect()))
553 .map_err(schema_error_response)
554}
555
556async fn delete_subject(
557 State(state): State<Arc<ServerState>>,
558 Path(subject): Path<String>,
559 Query(params): Query<QueryParams>,
560) -> Result<Json<Vec<u32>>, (StatusCode, Json<ErrorResponse>)> {
561 if *state.mode.read() == RegistryMode::ReadOnly {
563 return Err((
564 StatusCode::FORBIDDEN,
565 Json(ErrorResponse {
566 error_code: 40301,
567 message: "Registry is in READONLY mode".to_string(),
568 }),
569 ));
570 }
571
572 state
573 .registry
574 .delete_subject(subject.as_str(), params.permanent)
575 .await
576 .map(Json)
577 .map_err(schema_error_response)
578}
579
580async fn list_deleted_subjects(
582 State(state): State<Arc<ServerState>>,
583) -> Result<Json<Vec<String>>, (StatusCode, Json<ErrorResponse>)> {
584 state
585 .registry
586 .list_deleted_subjects()
587 .await
588 .map(|subjects| Json(subjects.into_iter().map(|s| s.0).collect()))
589 .map_err(schema_error_response)
590}
591
592async fn undelete_subject(
594 State(state): State<Arc<ServerState>>,
595 Path(subject): Path<String>,
596) -> Result<Json<Vec<u32>>, (StatusCode, Json<ErrorResponse>)> {
597 if *state.mode.read() == RegistryMode::ReadOnly {
599 return Err((
600 StatusCode::FORBIDDEN,
601 Json(ErrorResponse {
602 error_code: 40301,
603 message: "Registry is in READONLY mode".to_string(),
604 }),
605 ));
606 }
607
608 state
609 .registry
610 .undelete_subject(subject.as_str())
611 .await
612 .map(Json)
613 .map_err(schema_error_response)
614}
615
616async fn list_subject_versions(
617 State(state): State<Arc<ServerState>>,
618 Path(subject): Path<String>,
619) -> Result<Json<Vec<u32>>, (StatusCode, Json<ErrorResponse>)> {
620 state
621 .registry
622 .list_versions(subject.as_str())
623 .await
624 .map(Json)
625 .map_err(schema_error_response)
626}
627
628async fn register_schema(
629 State(state): State<Arc<ServerState>>,
630 Path(subject): Path<String>,
631 Json(req): Json<RegisterSchemaRequest>,
632) -> Result<Json<RegisterSchemaResponse>, (StatusCode, Json<ErrorResponse>)> {
633 if *state.mode.read() == RegistryMode::ReadOnly {
635 return Err((
636 StatusCode::FORBIDDEN,
637 Json(ErrorResponse {
638 error_code: 40301,
639 message: "Registry is in READONLY mode".to_string(),
640 }),
641 ));
642 }
643
644 let schema_type = req
645 .schema_type
646 .as_deref()
647 .unwrap_or("AVRO")
648 .parse::<SchemaType>()
649 .unwrap_or(SchemaType::Avro);
650
651 let references: Vec<crate::types::SchemaReference> = req
653 .references
654 .into_iter()
655 .map(|r| crate::types::SchemaReference {
656 name: r.name,
657 subject: r.subject,
658 version: r.version,
659 })
660 .collect();
661
662 state
663 .registry
664 .register_with_references(subject.as_str(), schema_type, &req.schema, references)
665 .await
666 .map(|id| Json(RegisterSchemaResponse { id: id.0 }))
667 .map_err(schema_error_response)
668}
669
670async fn get_subject_version(
671 State(state): State<Arc<ServerState>>,
672 Path((subject, version)): Path<(String, String)>,
673) -> Result<Json<SubjectVersionResponse>, (StatusCode, Json<ErrorResponse>)> {
674 let version = if version == "latest" {
675 state
676 .registry
677 .get_latest(subject.as_str())
678 .await
679 .map_err(schema_error_response)?
680 .version
681 } else {
682 SchemaVersion::new(version.parse().unwrap_or(0))
683 };
684
685 state
686 .registry
687 .get_by_version(subject.as_str(), version)
688 .await
689 .map(|sv| {
690 Json(SubjectVersionResponse {
691 subject: sv.subject.0,
692 version: sv.version.0,
693 id: sv.id.0,
694 schema: sv.schema,
695 schema_type: sv.schema_type.to_string(),
696 })
697 })
698 .map_err(schema_error_response)
699}
700
701async fn delete_version(
702 State(state): State<Arc<ServerState>>,
703 Path((subject, version)): Path<(String, u32)>,
704 Query(params): Query<QueryParams>,
705) -> Result<Json<u32>, (StatusCode, Json<ErrorResponse>)> {
706 if *state.mode.read() == RegistryMode::ReadOnly {
708 return Err((
709 StatusCode::FORBIDDEN,
710 Json(ErrorResponse {
711 error_code: 40301,
712 message: "Registry is in READONLY mode".to_string(),
713 }),
714 ));
715 }
716
717 state
718 .registry
719 .delete_version(
720 subject.as_str(),
721 SchemaVersion::new(version),
722 params.permanent,
723 )
724 .await
725 .map(|_| Json(version))
726 .map_err(schema_error_response)
727}
728
729async fn get_referenced_by(
731 State(state): State<Arc<ServerState>>,
732 Path((subject, version)): Path<(String, String)>,
733) -> Result<Json<Vec<u32>>, (StatusCode, Json<ErrorResponse>)> {
734 let version = if version == "latest" {
735 let versions = state
737 .registry
738 .list_versions(subject.as_str())
739 .await
740 .map_err(schema_error_response)?;
741 versions.into_iter().max().unwrap_or(1)
742 } else {
743 version.parse().unwrap_or(1)
744 };
745
746 state
747 .registry
748 .get_schemas_referencing(subject.as_str(), SchemaVersion::new(version))
749 .await
750 .map(|ids| Json(ids.into_iter().map(|id| id.0).collect()))
751 .map_err(schema_error_response)
752}
753
754async fn check_compatibility(
755 State(state): State<Arc<ServerState>>,
756 Path((subject, version)): Path<(String, String)>,
757 Json(req): Json<CompatibilityRequest>,
758) -> Result<Json<CompatibilityResponse>, (StatusCode, Json<ErrorResponse>)> {
759 let schema_type = req
760 .schema_type
761 .as_deref()
762 .unwrap_or("AVRO")
763 .parse::<SchemaType>()
764 .unwrap_or(SchemaType::Avro);
765
766 let version = if version == "latest" {
767 None
768 } else {
769 Some(SchemaVersion::new(version.parse().unwrap_or(0)))
770 };
771
772 state
773 .registry
774 .check_compatibility(subject.as_str(), schema_type, &req.schema, version)
775 .await
776 .map(|result| {
777 Json(CompatibilityResponse {
778 is_compatible: result.is_compatible,
779 messages: result.messages,
780 })
781 })
782 .map_err(schema_error_response)
783}
784
785async fn get_global_config(State(state): State<Arc<ServerState>>) -> Json<ConfigResponse> {
786 Json(ConfigResponse {
787 compatibility_level: state.default_compatibility.read().to_string(),
788 })
789}
790
791async fn update_global_config(
792 State(state): State<Arc<ServerState>>,
793 Json(req): Json<ConfigRequest>,
794) -> Result<Json<ConfigResponse>, (StatusCode, Json<ErrorResponse>)> {
795 let level: CompatibilityLevel = req.compatibility.parse().map_err(|_| {
796 (
797 StatusCode::UNPROCESSABLE_ENTITY,
798 Json(ErrorResponse {
799 error_code: 42203,
800 message: format!("Invalid compatibility level: {}", req.compatibility),
801 }),
802 )
803 })?;
804
805 *state.default_compatibility.write() = level;
806
807 Ok(Json(ConfigResponse {
808 compatibility_level: level.to_string(),
809 }))
810}
811
812async fn get_subject_config(
813 State(state): State<Arc<ServerState>>,
814 Path(subject): Path<String>,
815) -> Json<ConfigResponse> {
816 let level = state
817 .registry
818 .get_subject_compatibility(&Subject::new(subject));
819 Json(ConfigResponse {
820 compatibility_level: level.to_string(),
821 })
822}
823
824async fn update_subject_config(
825 State(state): State<Arc<ServerState>>,
826 Path(subject): Path<String>,
827 Json(req): Json<ConfigRequest>,
828) -> Result<Json<ConfigResponse>, (StatusCode, Json<ErrorResponse>)> {
829 let level: CompatibilityLevel = req.compatibility.parse().map_err(|_| {
830 (
831 StatusCode::UNPROCESSABLE_ENTITY,
832 Json(ErrorResponse {
833 error_code: 42203,
834 message: format!("Invalid compatibility level: {}", req.compatibility),
835 }),
836 )
837 })?;
838
839 state
840 .registry
841 .set_subject_compatibility(subject.as_str(), level);
842
843 Ok(Json(ConfigResponse {
844 compatibility_level: level.to_string(),
845 }))
846}
847
848async fn get_mode(State(state): State<Arc<ServerState>>) -> Json<ModeResponse> {
849 Json(ModeResponse {
850 mode: state.mode.read().to_string(),
851 })
852}
853
854async fn update_mode(
855 State(state): State<Arc<ServerState>>,
856 Json(req): Json<ModeRequest>,
857) -> Result<Json<ModeResponse>, (StatusCode, Json<ErrorResponse>)> {
858 let mode = match req.mode.to_uppercase().as_str() {
859 "READWRITE" => RegistryMode::ReadWrite,
860 "READONLY" => RegistryMode::ReadOnly,
861 "IMPORT" => RegistryMode::Import,
862 _ => {
863 return Err((
864 StatusCode::UNPROCESSABLE_ENTITY,
865 Json(ErrorResponse {
866 error_code: 42204,
867 message: format!("Invalid mode: {}", req.mode),
868 }),
869 ))
870 }
871 };
872
873 *state.mode.write() = mode;
874
875 Ok(Json(ModeResponse {
876 mode: mode.to_string(),
877 }))
878}
879
/// Body returned by the version-state endpoints.
#[derive(Serialize)]
struct VersionStateResponse {
    /// Version state rendered as text (e.g. `"DEPRECATED"`).
    state: String,
}
888
/// Body of `PUT /subjects/:subject/versions/:version/state`.
#[derive(Deserialize)]
struct VersionStateRequest {
    /// Requested state; `set_version_state` parses it into a `VersionState`.
    state: String,
}
893
894async fn get_version_state(
895 State(state): State<Arc<ServerState>>,
896 Path((subject, version)): Path<(String, u32)>,
897) -> Result<Json<VersionStateResponse>, (StatusCode, Json<ErrorResponse>)> {
898 state
899 .registry
900 .get_version_state(subject.as_str(), SchemaVersion::new(version))
901 .await
902 .map(|s| {
903 Json(VersionStateResponse {
904 state: s.to_string(),
905 })
906 })
907 .map_err(schema_error_response)
908}
909
910async fn set_version_state(
911 State(state): State<Arc<ServerState>>,
912 Path((subject, version)): Path<(String, u32)>,
913 Json(req): Json<VersionStateRequest>,
914) -> Result<Json<VersionStateResponse>, (StatusCode, Json<ErrorResponse>)> {
915 let version_state: VersionState = req.state.parse().map_err(|_| {
916 (
917 StatusCode::UNPROCESSABLE_ENTITY,
918 Json(ErrorResponse {
919 error_code: 42205,
920 message: format!("Invalid version state: {}", req.state),
921 }),
922 )
923 })?;
924
925 state
926 .registry
927 .set_version_state(subject.as_str(), SchemaVersion::new(version), version_state)
928 .await
929 .map_err(schema_error_response)?;
930
931 Ok(Json(VersionStateResponse {
932 state: version_state.to_string(),
933 }))
934}
935
936async fn deprecate_version(
937 State(state): State<Arc<ServerState>>,
938 Path((subject, version)): Path<(String, u32)>,
939) -> Result<Json<VersionStateResponse>, (StatusCode, Json<ErrorResponse>)> {
940 state
941 .registry
942 .deprecate_version(subject.as_str(), SchemaVersion::new(version))
943 .await
944 .map_err(schema_error_response)?;
945
946 Ok(Json(VersionStateResponse {
947 state: "DEPRECATED".to_string(),
948 }))
949}
950
951async fn disable_version(
952 State(state): State<Arc<ServerState>>,
953 Path((subject, version)): Path<(String, u32)>,
954) -> Result<Json<VersionStateResponse>, (StatusCode, Json<ErrorResponse>)> {
955 state
956 .registry
957 .disable_version(subject.as_str(), SchemaVersion::new(version))
958 .await
959 .map_err(schema_error_response)?;
960
961 Ok(Json(VersionStateResponse {
962 state: "DISABLED".to_string(),
963 }))
964}
965
966async fn enable_version(
967 State(state): State<Arc<ServerState>>,
968 Path((subject, version)): Path<(String, u32)>,
969) -> Result<Json<VersionStateResponse>, (StatusCode, Json<ErrorResponse>)> {
970 state
971 .registry
972 .enable_version(subject.as_str(), SchemaVersion::new(version))
973 .await
974 .map_err(schema_error_response)?;
975
976 Ok(Json(VersionStateResponse {
977 state: "ENABLED".to_string(),
978 }))
979}
980
/// Body of `POST /subjects/:subject/validate`.
#[derive(Deserialize)]
struct ValidateSchemaRequest {
    /// Schema text to validate.
    schema: String,
    /// `schemaType` in JSON. `validate_schema` treats a missing or unparseable value as AVRO.
    #[serde(rename = "schemaType", default)]
    schema_type: Option<String>,
}
991
/// Result of running the validation rules against a schema.
#[derive(Serialize)]
struct ValidationResponse {
    /// `report.is_valid()` from the registry's validation report.
    is_valid: bool,
    /// Error messages from the report.
    errors: Vec<String>,
    /// Warning messages from the report.
    warnings: Vec<String>,
}
998
999async fn validate_schema(
1000 State(state): State<Arc<ServerState>>,
1001 Path(subject): Path<String>,
1002 Json(req): Json<ValidateSchemaRequest>,
1003) -> Result<Json<ValidationResponse>, (StatusCode, Json<ErrorResponse>)> {
1004 let schema_type = req
1005 .schema_type
1006 .as_deref()
1007 .unwrap_or("AVRO")
1008 .parse::<SchemaType>()
1009 .unwrap_or(SchemaType::Avro);
1010
1011 state
1012 .registry
1013 .validate_schema(schema_type, &subject, &req.schema)
1014 .map(|report| {
1015 Json(ValidationResponse {
1016 is_valid: report.is_valid(),
1017 errors: report.error_messages(),
1018 warnings: report.warning_messages(),
1019 })
1020 })
1021 .map_err(schema_error_response)
1022}
1023
/// Wire view of a validation rule.
#[derive(Serialize)]
struct ValidationRuleResponse {
    /// Rule name (used as the key by `DELETE /config/validation/rules/:name`).
    name: String,
    /// Rule type, rendered with `Debug` formatting by the handlers.
    rule_type: String,
    /// Rule configuration string.
    config: String,
    /// Validation level, rendered with `Debug` formatting by the handlers.
    level: String,
    /// Optional description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// Whether the rule reports itself as active.
    active: bool,
}
1034
/// Body of `POST /config/validation/rules`.
#[derive(Deserialize)]
struct AddValidationRuleRequest {
    /// Rule name.
    name: String,
    /// Rule type; parsed into a `ValidationRuleType`.
    rule_type: String,
    /// Rule configuration string.
    config: String,
    /// Validation level; `"ERROR"` when omitted (see `default_level`).
    #[serde(default = "default_level")]
    level: String,
    /// Optional human-readable description.
    description: Option<String>,
    /// Subjects to scope the rule to; empty (the default) leaves the rule unscoped.
    #[serde(default)]
    subjects: Vec<String>,
    /// Schema types to scope the rule to; empty (the default) leaves the rule unscoped.
    #[serde(default)]
    schema_types: Vec<String>,
}
1048
/// Serde default for `AddValidationRuleRequest::level`: rules are errors unless the request
/// says otherwise.
fn default_level() -> String {
    String::from("ERROR")
}
1052
1053async fn list_validation_rules(
1054 State(state): State<Arc<ServerState>>,
1055) -> Json<Vec<ValidationRuleResponse>> {
1056 let rules = state.registry.validation_engine().read().list_rules();
1057 Json(
1058 rules
1059 .into_iter()
1060 .map(|r| ValidationRuleResponse {
1061 name: r.name().to_string(),
1062 rule_type: format!("{:?}", r.rule_type()),
1063 config: r.config().to_string(),
1064 level: format!("{:?}", r.level()),
1065 description: r.description().map(|s| s.to_string()),
1066 active: r.is_active(),
1067 })
1068 .collect(),
1069 )
1070}
1071
1072async fn add_validation_rule(
1073 State(state): State<Arc<ServerState>>,
1074 Json(req): Json<AddValidationRuleRequest>,
1075) -> Result<Json<ValidationRuleResponse>, (StatusCode, Json<ErrorResponse>)> {
1076 let rule_type: ValidationRuleType = req.rule_type.parse().map_err(|_| {
1077 (
1078 StatusCode::UNPROCESSABLE_ENTITY,
1079 Json(ErrorResponse {
1080 error_code: 42206,
1081 message: format!("Invalid rule type: {}", req.rule_type),
1082 }),
1083 )
1084 })?;
1085
1086 let level: ValidationLevel = req.level.parse().map_err(|_| {
1087 (
1088 StatusCode::UNPROCESSABLE_ENTITY,
1089 Json(ErrorResponse {
1090 error_code: 42207,
1091 message: format!("Invalid validation level: {}", req.level),
1092 }),
1093 )
1094 })?;
1095
1096 let mut rule = ValidationRule::new(&req.name, rule_type, &req.config).with_level(level);
1097
1098 if let Some(desc) = &req.description {
1099 rule = rule.with_description(desc);
1100 }
1101
1102 if !req.subjects.is_empty() {
1103 rule = rule.for_subjects(req.subjects);
1104 }
1105
1106 if !req.schema_types.is_empty() {
1107 let schema_types: Vec<SchemaType> = req
1108 .schema_types
1109 .iter()
1110 .filter_map(|s| s.parse().ok())
1111 .collect();
1112 rule = rule.for_schema_types(schema_types);
1113 }
1114
1115 state.registry.add_validation_rule(rule.clone());
1116
1117 Ok(Json(ValidationRuleResponse {
1118 name: rule.name().to_string(),
1119 rule_type: format!("{:?}", rule.rule_type()),
1120 config: rule.config().to_string(),
1121 level: format!("{:?}", rule.level()),
1122 description: rule.description().map(|s| s.to_string()),
1123 active: rule.is_active(),
1124 }))
1125}
1126
1127async fn delete_validation_rule(
1128 State(state): State<Arc<ServerState>>,
1129 Path(name): Path<String>,
1130) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
1131 let removed = state
1132 .registry
1133 .validation_engine()
1134 .write()
1135 .remove_rule(&name);
1136 if removed {
1137 Ok(StatusCode::NO_CONTENT)
1138 } else {
1139 Err((
1140 StatusCode::NOT_FOUND,
1141 Json(ErrorResponse {
1142 error_code: 40404,
1143 message: format!("Validation rule not found: {}", name),
1144 }),
1145 ))
1146 }
1147}
1148
/// Wire view of a schema context.
#[derive(Serialize)]
struct ContextResponse {
    /// Context name.
    name: String,
    /// Optional description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// Whether the context reports itself as active.
    active: bool,
}
1160
/// Body of `POST /contexts`.
#[derive(Deserialize)]
struct CreateContextRequest {
    /// Name of the context to create.
    name: String,
    /// Optional description attached to the new context.
    description: Option<String>,
}
1166
1167async fn list_contexts(State(state): State<Arc<ServerState>>) -> Json<Vec<ContextResponse>> {
1168 let contexts = state.registry.list_contexts();
1169 Json(
1170 contexts
1171 .into_iter()
1172 .map(|c| ContextResponse {
1173 name: c.name().to_string(),
1174 description: c.description().map(|s| s.to_string()),
1175 active: c.is_active(),
1176 })
1177 .collect(),
1178 )
1179}
1180
1181async fn create_context(
1182 State(state): State<Arc<ServerState>>,
1183 Json(req): Json<CreateContextRequest>,
1184) -> Result<Json<ContextResponse>, (StatusCode, Json<ErrorResponse>)> {
1185 let mut context = SchemaContext::new(&req.name);
1186 if let Some(desc) = &req.description {
1187 context = context.with_description(desc);
1188 }
1189
1190 state
1191 .registry
1192 .create_context(context.clone())
1193 .map_err(schema_error_response)?;
1194
1195 Ok(Json(ContextResponse {
1196 name: context.name().to_string(),
1197 description: context.description().map(|s| s.to_string()),
1198 active: context.is_active(),
1199 }))
1200}
1201
1202async fn get_context(
1203 State(state): State<Arc<ServerState>>,
1204 Path(context_name): Path<String>,
1205) -> Result<Json<ContextResponse>, (StatusCode, Json<ErrorResponse>)> {
1206 state
1207 .registry
1208 .get_context(&context_name)
1209 .map(|c| {
1210 Json(ContextResponse {
1211 name: c.name().to_string(),
1212 description: c.description().map(|s| s.to_string()),
1213 active: c.is_active(),
1214 })
1215 })
1216 .ok_or_else(|| {
1217 (
1218 StatusCode::NOT_FOUND,
1219 Json(ErrorResponse {
1220 error_code: 40405,
1221 message: format!("Context not found: {}", context_name),
1222 }),
1223 )
1224 })
1225}
1226
1227async fn delete_context(
1228 State(state): State<Arc<ServerState>>,
1229 Path(context_name): Path<String>,
1230) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
1231 state
1232 .registry
1233 .delete_context(&context_name)
1234 .map(|_| StatusCode::NO_CONTENT)
1235 .map_err(schema_error_response)
1236}
1237
1238async fn list_context_subjects(
1239 State(state): State<Arc<ServerState>>,
1240 Path(context_name): Path<String>,
1241) -> Json<Vec<String>> {
1242 Json(state.registry.list_subjects_in_context(&context_name))
1243}
1244
/// Body of `GET /stats`, copied field-for-field from the registry's stats.
#[derive(Serialize)]
struct StatsResponse {
    /// Number of schemas, as reported by the registry.
    schema_count: usize,
    /// Number of subjects, as reported by the registry.
    subject_count: usize,
    /// Number of contexts, as reported by the registry.
    context_count: usize,
    /// Registry cache size, as reported by the registry (units not visible here).
    cache_size: usize,
}
1256
1257async fn get_stats(State(state): State<Arc<ServerState>>) -> Json<StatsResponse> {
1258 let stats = state.registry.stats().await;
1259 Json(StatsResponse {
1260 schema_count: stats.schema_count,
1261 subject_count: stats.subject_count,
1262 context_count: stats.context_count,
1263 cache_size: stats.cache_size,
1264 })
1265}
1266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::RegistryConfig;
    use axum::body::Body;
    use axum::http::Request;
    use tower::util::ServiceExt;

    /// Builds a router over a fresh in-memory registry with default server settings.
    async fn create_test_app() -> Router {
        let registry = SchemaRegistry::new(RegistryConfig::memory())
            .await
            .unwrap();
        SchemaServer::new(registry, ServerConfig::default()).router()
    }

    /// Sends a bodiless `GET uri` to a fresh app and returns the response status.
    async fn get_status(uri: &str) -> StatusCode {
        let app = create_test_app().await;
        let request = Request::builder().uri(uri).body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_root_endpoint() {
        assert_eq!(get_status("/").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_list_subjects_empty() {
        assert_eq!(get_status("/subjects").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_config() {
        assert_eq!(get_status("/config").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_mode() {
        assert_eq!(get_status("/mode").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        assert_eq!(get_status("/health").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_liveness_endpoint() {
        assert_eq!(get_status("/health/live").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_readiness_endpoint() {
        assert_eq!(get_status("/health/ready").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_stats_endpoint() {
        assert_eq!(get_status("/stats").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_list_contexts_empty() {
        assert_eq!(get_status("/contexts").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_list_validation_rules_empty() {
        assert_eq!(
            get_status("/config/validation/rules").await,
            StatusCode::OK
        );
    }
}