1use bytes::Bytes;
10use osproxy_core::FieldName;
11use osproxy_rewrite::{construct_id_bytes, inject_fields_bytes, validate_json};
12use osproxy_sink::{DocOp, WriteBatch, WriteOp};
13use osproxy_spi::{BodyTransform, DocIdRule, InjectedField, InjectedValue};
14use osproxy_tenancy::Resolved;
15use serde_json::Value;
16
17use crate::error::RequestError;
18
19pub fn build_write_batch(resolved: &Resolved, body: &[u8]) -> Result<WriteBatch, RequestError> {
33 let decision = &resolved.decision;
34 let partition = resolved.partition.as_str();
35
36 let (id, out_body) = apply_transform(body, &decision.body_transform, partition)?;
37 let routing = routing_for(&decision.body_transform, partition);
38
39 let op = WriteOp::new(
40 decision.target.clone(),
41 DocOp::Index {
42 id,
43 routing,
44 body: out_body,
45 },
46 decision.epoch,
47 )
48 .with_protocol(decision.upstream_protocol);
49 Ok(WriteBatch::single(op))
50}
51
52fn apply_transform(
55 body: &[u8],
56 transform: &BodyTransform,
57 partition: &str,
58) -> Result<(Option<String>, Bytes), RequestError> {
59 match transform {
60 BodyTransform::None => {
62 validate_json(body).map_err(RequestError::from)?;
63 Ok((None, Bytes::copy_from_slice(body)))
64 }
65 BodyTransform::Inject(fields) => Ok((None, Bytes::from(inject(body, fields, partition)?))),
66 BodyTransform::ConstructId(rule) => {
68 validate_json(body).map_err(RequestError::from)?;
69 Ok((
70 Some(build_id(rule, body, partition)?),
71 Bytes::copy_from_slice(body),
72 ))
73 }
74 BodyTransform::Both { inject: fields, id } => {
77 let out = Bytes::from(inject(body, fields, partition)?);
79 Ok((Some(build_id(id, body, partition)?), out))
80 }
81 }
82}
83
84fn inject(body: &[u8], fields: &[InjectedField], partition: &str) -> Result<Vec<u8>, RequestError> {
86 let resolved = resolve_values(fields, partition)?;
87 inject_fields_bytes(body, &resolved).map_err(RequestError::from)
88}
89
90fn build_id(rule: &DocIdRule, body: &[u8], partition: &str) -> Result<String, RequestError> {
92 construct_id_bytes(rule.template.as_str(), partition, body).map_err(RequestError::from)
93}
94
95fn routing_for(transform: &BodyTransform, partition: &str) -> Option<String> {
98 let rule = match transform {
99 BodyTransform::ConstructId(rule) | BodyTransform::Both { id: rule, .. } => Some(rule),
100 BodyTransform::None | BodyTransform::Inject(_) => None,
101 };
102 rule.filter(|r| r.set_routing).map(|_| partition.to_owned())
103}
104
105fn resolve_values(
112 fields: &[InjectedField],
113 partition: &str,
114) -> Result<Vec<(FieldName, Value)>, RequestError> {
115 fields
116 .iter()
117 .map(|field| {
118 let value = match &field.value {
119 InjectedValue::Constant(v) => v.clone(),
120 InjectedValue::PartitionId => Value::String(partition.to_owned()),
121 InjectedValue::FromPrincipal(_) | InjectedValue::FromHeader(_) => {
122 return Err(RequestError::Internal {
123 reason: "context-derived injected value reached the engine unresolved",
124 })
125 }
126 };
127 Ok((field.name.clone(), value))
128 })
129 .collect()
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use osproxy_core::{ClusterId, Epoch, IndexName, PartitionId, Target};
    use osproxy_rewrite::RewriteError;
    use osproxy_spi::{IdTemplate, Protocol, RouteDecision};
    use serde_json::json;

    /// Builds a `Resolved` for partition "acme" targeting index "shared" on
    /// cluster "eu-1" at epoch 4, carrying the given body transform.
    fn resolved(transform: BodyTransform) -> Resolved {
        Resolved {
            partition: PartitionId::from("acme"),
            decision: RouteDecision {
                target: Target::new(ClusterId::from("eu-1"), IndexName::from("shared")),
                upstream_protocol: Protocol::Http1,
                header_ops: Vec::new(),
                body_transform: transform,
                epoch: Epoch::new(4),
            },
            migration: osproxy_spi::MigrationPhase::Settled,
        }
    }

    /// Extracts `(id, routing, parsed body)` from the first op of `batch`,
    /// which is expected to be a `DocOp::Index`.
    fn index_op(batch: &WriteBatch) -> (&Option<String>, &Option<String>, Value) {
        match &batch.ops()[0].doc {
            DocOp::Index { id, routing, body } => {
                (id, routing, serde_json::from_slice(body).unwrap())
            }
            DocOp::Create { .. } | DocOp::Update { .. } | DocOp::Delete { .. } => {
                // Not expected from `build_write_batch`; fall back to sentinel
                // values instead of panicking.
                unreachable_delete()
            }
        }
    }

    /// Sentinel returned for non-Index ops: `None` id, `None` routing and a
    /// `Null` body, which are expected to trip the caller's assertions.
    /// NOTE(review): presumably used instead of `unreachable!`/`panic!` to
    /// satisfy crate lint settings. Confirm.
    fn unreachable_delete() -> (&'static Option<String>, &'static Option<String>, Value) {
        (&None, &None, Value::Null)
    }

    /// `Both` injects the field, renders the id from the body, and sets routing
    /// to the partition. The decision's epoch is carried onto the op.
    #[test]
    fn inject_and_construct_id_with_routing() {
        let inject = vec![InjectedField::new(
            FieldName::from("_tenant"),
            InjectedValue::PartitionId,
        )];
        let id = DocIdRule::new(IdTemplate::new("{partition}:{body.id}")).with_routing(true);
        let r = resolved(BodyTransform::Both { inject, id });
        let batch = build_write_batch(&r, br#"{ "id": 1001, "msg": "hi" }"#).unwrap();

        assert_eq!(batch.ops()[0].epoch, Epoch::new(4));
        let (id, routing, body) = index_op(&batch);
        assert_eq!(id.as_deref(), Some("acme:1001"));
        assert_eq!(routing.as_deref(), Some("acme"));
        assert_eq!(body["_tenant"], json!("acme"));
        assert_eq!(body["msg"], json!("hi"));
    }

    /// `Inject` alone rewrites the body but produces neither id nor routing.
    #[test]
    fn inject_only_has_no_id_or_routing() {
        let inject = vec![InjectedField::new(
            FieldName::from("_t"),
            InjectedValue::Constant(json!("acme")),
        )];
        let r = resolved(BodyTransform::Inject(inject));
        let batch = build_write_batch(&r, br#"{ "k": 1 }"#).unwrap();
        let (id, routing, body) = index_op(&batch);
        assert!(id.is_none());
        assert!(routing.is_none());
        assert_eq!(body["_t"], json!("acme"));
    }

    /// `ConstructId` without `with_routing(true)` sets an id but no routing.
    #[test]
    fn construct_id_without_routing() {
        let id = DocIdRule::new(IdTemplate::new("{partition}:{body.k}"));
        let r = resolved(BodyTransform::ConstructId(id));
        let batch = build_write_batch(&r, br#"{ "k": "v" }"#).unwrap();
        let (id, routing, _) = index_op(&batch);
        assert_eq!(id.as_deref(), Some("acme:v"));
        assert!(routing.is_none());
    }

    /// `None` forwards the body unchanged with no id or routing.
    #[test]
    fn none_transform_passes_body_through() {
        let r = resolved(BodyTransform::None);
        let batch = build_write_batch(&r, br#"{ "k": 1 }"#).unwrap();
        let (id, routing, body) = index_op(&batch);
        assert!(id.is_none());
        assert!(routing.is_none());
        assert_eq!(body, json!({ "k": 1 }));
    }

    /// A client-supplied field with the same name as an injected one is
    /// rejected instead of being overwritten.
    #[test]
    fn reserved_field_collision_is_rejected() {
        let inject = vec![InjectedField::new(
            FieldName::from("_t"),
            InjectedValue::Constant(json!("acme")),
        )];
        let r = resolved(BodyTransform::Inject(inject));
        let err = build_write_batch(&r, br#"{ "_t": "evil" }"#).unwrap_err();
        assert!(matches!(
            err,
            RequestError::Rewrite(RewriteError::ReservedFieldCollision { .. })
        ));
    }

    /// Even the pass-through transform validates that the body is JSON.
    #[test]
    fn malformed_body_is_rejected() {
        let r = resolved(BodyTransform::None);
        let err = build_write_batch(&r, b"not json").unwrap_err();
        assert!(matches!(
            err,
            RequestError::Rewrite(RewriteError::InvalidJson)
        ));
    }

    /// A principal-derived value that was never resolved surfaces as an
    /// internal error rather than a client error.
    #[test]
    fn unresolved_principal_value_is_internal_error() {
        let inject = vec![InjectedField::new(
            FieldName::from("_t"),
            InjectedValue::FromPrincipal("tenant".to_owned()),
        )];
        let r = resolved(BodyTransform::Inject(inject));
        let err = build_write_batch(&r, br#"{ "k": 1 }"#).unwrap_err();
        assert!(matches!(err, RequestError::Internal { .. }));
    }
}