1use osproxy_core::FieldName;
10use osproxy_rewrite::{construct_id_bytes, inject_fields_bytes, validate_json};
11use osproxy_sink::{DocOp, WriteBatch, WriteOp};
12use osproxy_spi::{BodyTransform, DocIdRule, InjectedField, InjectedValue};
13use osproxy_tenancy::Resolved;
14use serde_json::Value;
15
16use crate::error::RequestError;
17
18pub fn build_write_batch(resolved: &Resolved, body: &[u8]) -> Result<WriteBatch, RequestError> {
32 let decision = &resolved.decision;
33 let partition = resolved.partition.as_str();
34
35 let (id, out_body) = apply_transform(body, &decision.body_transform, partition)?;
36 let routing = routing_for(&decision.body_transform, partition);
37
38 let op = WriteOp::new(
39 decision.target.clone(),
40 DocOp::Index {
41 id,
42 routing,
43 body: out_body,
44 },
45 decision.epoch,
46 )
47 .with_protocol(decision.upstream_protocol);
48 Ok(WriteBatch::single(op))
49}
50
51fn apply_transform(
54 body: &[u8],
55 transform: &BodyTransform,
56 partition: &str,
57) -> Result<(Option<String>, Vec<u8>), RequestError> {
58 match transform {
59 BodyTransform::None => {
61 validate_json(body).map_err(RequestError::from)?;
62 Ok((None, body.to_vec()))
63 }
64 BodyTransform::Inject(fields) => Ok((None, inject(body, fields, partition)?)),
65 BodyTransform::ConstructId(rule) => {
67 validate_json(body).map_err(RequestError::from)?;
68 Ok((Some(build_id(rule, body, partition)?), body.to_vec()))
69 }
70 BodyTransform::Both { inject: fields, id } => {
73 let out = inject(body, fields, partition)?;
74 Ok((Some(build_id(id, body, partition)?), out))
75 }
76 }
77}
78
79fn inject(body: &[u8], fields: &[InjectedField], partition: &str) -> Result<Vec<u8>, RequestError> {
81 let resolved = resolve_values(fields, partition)?;
82 inject_fields_bytes(body, &resolved).map_err(RequestError::from)
83}
84
85fn build_id(rule: &DocIdRule, body: &[u8], partition: &str) -> Result<String, RequestError> {
87 construct_id_bytes(rule.template.as_str(), partition, body).map_err(RequestError::from)
88}
89
90fn routing_for(transform: &BodyTransform, partition: &str) -> Option<String> {
93 let rule = match transform {
94 BodyTransform::ConstructId(rule) | BodyTransform::Both { id: rule, .. } => Some(rule),
95 BodyTransform::None | BodyTransform::Inject(_) => None,
96 };
97 rule.filter(|r| r.set_routing).map(|_| partition.to_owned())
98}
99
100fn resolve_values(
107 fields: &[InjectedField],
108 partition: &str,
109) -> Result<Vec<(FieldName, Value)>, RequestError> {
110 fields
111 .iter()
112 .map(|field| {
113 let value = match &field.value {
114 InjectedValue::Constant(v) => v.clone(),
115 InjectedValue::PartitionId => Value::String(partition.to_owned()),
116 InjectedValue::FromPrincipal(_) | InjectedValue::FromHeader(_) => {
117 return Err(RequestError::Internal {
118 reason: "context-derived injected value reached the engine unresolved",
119 })
120 }
121 };
122 Ok((field.name.clone(), value))
123 })
124 .collect()
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use osproxy_core::{ClusterId, Epoch, IndexName, PartitionId, Target};
    use osproxy_rewrite::RewriteError;
    use osproxy_spi::{IdTemplate, Protocol, RouteDecision};
    use serde_json::json;

    /// A resolution for partition `acme` targeting `eu-1/shared` at epoch 4,
    /// using the given body transform.
    fn resolved(transform: BodyTransform) -> Resolved {
        Resolved {
            partition: PartitionId::from("acme"),
            decision: RouteDecision {
                target: Target::new(ClusterId::from("eu-1"), IndexName::from("shared")),
                upstream_protocol: Protocol::Http1,
                header_ops: Vec::new(),
                body_transform: transform,
                epoch: Epoch::new(4),
            },
            migration: osproxy_spi::MigrationPhase::Settled,
        }
    }

    /// Unpacks the single `Index` op of `batch`, parsing its body as JSON.
    ///
    /// The helper fails the test (via `unwrap`) in two cases:
    /// - the batch does not hold exactly one op;
    /// - that op is not an `Index`.
    ///
    /// This is deliberate: the earlier version returned placeholder
    /// `None`/`Null` values, and a wrong op kind could then be misread by
    /// assertions such as `id.is_none()`.
    fn index_op(batch: &WriteBatch) -> (&Option<String>, &Option<String>, Value) {
        assert_eq!(batch.ops().len(), 1, "expected a single-op batch");
        let parts = match &batch.ops()[0].doc {
            DocOp::Index { id, routing, body } => Some((id, routing, body)),
            DocOp::Create { .. } | DocOp::Update { .. } | DocOp::Delete { .. } => None,
        };
        let (id, routing, body) = parts.unwrap();
        (id, routing, serde_json::from_slice(body).unwrap())
    }

    #[test]
    fn inject_and_construct_id_with_routing() {
        let inject = vec![InjectedField::new(
            FieldName::from("_tenant"),
            InjectedValue::PartitionId,
        )];
        let id = DocIdRule::new(IdTemplate::new("{partition}:{body.id}")).with_routing(true);
        let r = resolved(BodyTransform::Both { inject, id });
        let batch = build_write_batch(&r, br#"{ "id": 1001, "msg": "hi" }"#).unwrap();

        assert_eq!(batch.ops()[0].epoch, Epoch::new(4));
        let (id, routing, body) = index_op(&batch);
        assert_eq!(id.as_deref(), Some("acme:1001"));
        assert_eq!(routing.as_deref(), Some("acme"));
        assert_eq!(body["_tenant"], json!("acme"));
        assert_eq!(body["msg"], json!("hi"));
    }

    #[test]
    fn inject_only_has_no_id_or_routing() {
        let inject = vec![InjectedField::new(
            FieldName::from("_t"),
            InjectedValue::Constant(json!("acme")),
        )];
        let r = resolved(BodyTransform::Inject(inject));
        let batch = build_write_batch(&r, br#"{ "k": 1 }"#).unwrap();
        let (id, routing, body) = index_op(&batch);
        assert!(id.is_none());
        assert!(routing.is_none());
        assert_eq!(body["_t"], json!("acme"));
    }

    #[test]
    fn construct_id_without_routing() {
        let id = DocIdRule::new(IdTemplate::new("{partition}:{body.k}"));
        let r = resolved(BodyTransform::ConstructId(id));
        let batch = build_write_batch(&r, br#"{ "k": "v" }"#).unwrap();
        let (id, routing, _) = index_op(&batch);
        assert_eq!(id.as_deref(), Some("acme:v"));
        assert!(routing.is_none());
    }

    #[test]
    fn none_transform_passes_body_through() {
        let r = resolved(BodyTransform::None);
        let batch = build_write_batch(&r, br#"{ "k": 1 }"#).unwrap();
        let (id, routing, body) = index_op(&batch);
        assert!(id.is_none());
        assert!(routing.is_none());
        assert_eq!(body, json!({ "k": 1 }));
    }

    #[test]
    fn reserved_field_collision_is_rejected() {
        let inject = vec![InjectedField::new(
            FieldName::from("_t"),
            InjectedValue::Constant(json!("acme")),
        )];
        let r = resolved(BodyTransform::Inject(inject));
        let err = build_write_batch(&r, br#"{ "_t": "evil" }"#).unwrap_err();
        assert!(matches!(
            err,
            RequestError::Rewrite(RewriteError::ReservedFieldCollision { .. })
        ));
    }

    #[test]
    fn malformed_body_is_rejected() {
        let r = resolved(BodyTransform::None);
        let err = build_write_batch(&r, b"not json").unwrap_err();
        assert!(matches!(
            err,
            RequestError::Rewrite(RewriteError::InvalidJson)
        ));
    }

    #[test]
    fn construct_id_rejects_malformed_body() {
        // The ConstructId path validates the body before rendering the id.
        let id = DocIdRule::new(IdTemplate::new("{partition}:{body.k}"));
        let r = resolved(BodyTransform::ConstructId(id));
        let err = build_write_batch(&r, b"not json").unwrap_err();
        assert!(matches!(
            err,
            RequestError::Rewrite(RewriteError::InvalidJson)
        ));
    }

    #[test]
    fn unresolved_principal_value_is_internal_error() {
        let inject = vec![InjectedField::new(
            FieldName::from("_t"),
            InjectedValue::FromPrincipal("tenant".to_owned()),
        )];
        let r = resolved(BodyTransform::Inject(inject));
        let err = build_write_batch(&r, br#"{ "k": 1 }"#).unwrap_err();
        assert!(matches!(err, RequestError::Internal { .. }));
    }
}