1use std::collections::{BTreeMap, HashMap};
8
9use crate::manifest::{
10 DefaultPrivilege, DefaultPrivilegeGrant, Grant, MemberSpec, Membership, ObjectTarget,
11 PolicyManifest, RoleDefinition,
12};
13use crate::model::RoleGraph;
14
15pub fn role_graph_to_manifest(graph: &RoleGraph) -> PolicyManifest {
21 let roles: Vec<RoleDefinition> = graph
23 .roles
24 .iter()
25 .map(|(name, state)| {
26 let defaults = crate::model::RoleState::default();
27 RoleDefinition {
28 name: name.clone(),
29 login: if state.login != defaults.login {
30 Some(state.login)
31 } else {
32 None
33 },
34 superuser: if state.superuser != defaults.superuser {
35 Some(state.superuser)
36 } else {
37 None
38 },
39 createdb: if state.createdb != defaults.createdb {
40 Some(state.createdb)
41 } else {
42 None
43 },
44 createrole: if state.createrole != defaults.createrole {
45 Some(state.createrole)
46 } else {
47 None
48 },
49 inherit: if state.inherit != defaults.inherit {
50 Some(state.inherit)
51 } else {
52 None
53 },
54 replication: if state.replication != defaults.replication {
55 Some(state.replication)
56 } else {
57 None
58 },
59 bypassrls: if state.bypassrls != defaults.bypassrls {
60 Some(state.bypassrls)
61 } else {
62 None
63 },
64 connection_limit: if state.connection_limit != defaults.connection_limit {
65 Some(state.connection_limit)
66 } else {
67 None
68 },
69 comment: state.comment.clone(),
70 password: None, password_valid_until: state.password_valid_until.clone(),
72 }
73 })
74 .collect();
75
76 let grants: Vec<Grant> = graph
78 .grants
79 .iter()
80 .map(|(key, state)| Grant {
81 role: key.role.clone(),
82 privileges: state.privileges.iter().copied().collect(),
83 on: ObjectTarget {
84 object_type: key.object_type,
85 schema: key.schema.clone(),
86 name: key.name.clone(),
87 },
88 })
89 .collect();
90
91 let mut dp_groups: BTreeMap<(String, String), Vec<DefaultPrivilegeGrant>> = BTreeMap::new();
94 for (key, state) in &graph.default_privileges {
95 dp_groups
96 .entry((key.owner.clone(), key.schema.clone()))
97 .or_default()
98 .push(DefaultPrivilegeGrant {
99 role: Some(key.grantee.clone()),
100 privileges: state.privileges.iter().copied().collect(),
101 on_type: key.on_type,
102 });
103 }
104 let default_privileges: Vec<DefaultPrivilege> = dp_groups
105 .into_iter()
106 .map(|((owner, schema), grant)| DefaultPrivilege {
107 owner: Some(owner),
108 schema,
109 grant,
110 })
111 .collect();
112
113 let mut membership_map: BTreeMap<String, Vec<MemberSpec>> = BTreeMap::new();
116 for edge in &graph.memberships {
117 membership_map
118 .entry(edge.role.clone())
119 .or_default()
120 .push(MemberSpec {
121 name: edge.member.clone(),
122 inherit: edge.inherit,
123 admin: edge.admin,
124 });
125 }
126 let memberships: Vec<Membership> = membership_map
127 .into_iter()
128 .map(|(role, members)| Membership { role, members })
129 .collect();
130
131 PolicyManifest {
132 default_owner: None,
133 auth_providers: Vec::new(),
134 profiles: HashMap::new(),
135 schemas: Vec::new(),
136 roles,
137 grants,
138 default_privileges,
139 memberships,
140 retirements: Vec::new(),
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diff::diff;
    use crate::manifest::{expand_manifest, parse_manifest};
    use crate::model::RoleGraph;

    /// Manifest with one role that sets no attributes and one role that
    /// overrides `login` and `connection_limit`.
    const TWO_ROLES_YAML: &str = r#"
roles:
  - name: basic-role
  - name: login-role
    login: true
    connection_limit: 5
"#;

    /// Parses [`TWO_ROLES_YAML`], builds its graph without a default owner,
    /// and returns the manifest exported from that graph.
    fn export_two_roles() -> PolicyManifest {
        let parsed = parse_manifest(TWO_ROLES_YAML).unwrap();
        let expanded = expand_manifest(&parsed).unwrap();
        let graph = RoleGraph::from_expanded(&expanded, None).unwrap();
        role_graph_to_manifest(&graph)
    }

    /// Looks up an exported role by name, panicking when it is absent.
    fn role_named<'a>(manifest: &'a PolicyManifest, name: &str) -> &'a RoleDefinition {
        manifest
            .roles
            .iter()
            .find(|role| role.name == name)
            .unwrap()
    }

    /// Exporting a graph and importing the result again must produce a graph
    /// that has no differences from the one we started with.
    #[test]
    fn round_trip_export_import() {
        let yaml = r#"
default_owner: app_owner

profiles:
  editor:
    grants:
      - privileges: [USAGE]
        on: { type: schema }
      - privileges: [SELECT, INSERT, UPDATE, DELETE]
        on: { type: table, name: "*" }
    default_privileges:
      - privileges: [SELECT, INSERT, UPDATE, DELETE]
        on_type: table

schemas:
  - name: inventory
    profiles: [editor]

roles:
  - name: analytics
    login: true
    comment: "Analytics role"

memberships:
  - role: inventory-editor
    members:
      - name: "user@example.com"
        inherit: true
"#;
        // Expands a manifest and builds its graph, using the manifest's own
        // default owner.
        let build_graph = |source: &PolicyManifest| {
            let expanded = expand_manifest(source).unwrap();
            RoleGraph::from_expanded(&expanded, source.default_owner.as_deref()).unwrap()
        };

        let manifest = parse_manifest(yaml).unwrap();
        let original = build_graph(&manifest);

        let exported_manifest = role_graph_to_manifest(&original);
        let reimported = build_graph(&exported_manifest);

        let changes = diff(&original, &reimported);
        assert!(
            changes.is_empty(),
            "round-trip produced unexpected changes: {changes:?}"
        );
    }

    /// Attributes left at their default are exported as `None`; overridden
    /// ones carry their value.
    #[test]
    fn export_only_emits_non_default_attributes() {
        let exported = export_two_roles();

        let basic = role_named(&exported, "basic-role");
        assert!(basic.login.is_none());
        assert!(basic.superuser.is_none());
        assert!(basic.connection_limit.is_none());

        let login = role_named(&exported, "login-role");
        assert_eq!(login.login, Some(true));
        assert_eq!(login.connection_limit, Some(5));
    }

    /// The serialized manifest must not spell out `null` for unset fields,
    /// while still containing the attributes that were set.
    #[test]
    fn exported_yaml_omits_null_fields() {
        let exported = export_two_roles();
        let serialized = serde_yaml::to_string(&exported).unwrap();

        assert!(
            !serialized.contains("null"),
            "serialized YAML should not contain null fields, got:\n{serialized}"
        );
        assert!(serialized.contains("login: true"), "got:\n{serialized}");
        assert!(
            serialized.contains("connection_limit: 5"),
            "got:\n{serialized}"
        );
    }
}