update-ssh-keys 0.7.0

A tool for managing authorized SSH keys
Documentation
#!/usr/bin/env python2
# Copyright 2017 CoreOS, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pwd
import shutil
import subprocess
import tempfile
import unittest

script_path = os.path.abspath('%s/../../target/debug/update-ssh-keys' % __file__)

test_keys = {
    'valid1': 'ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDULTftpWMj4nD+7Ps'
              'B8itam2T6Aqm9Z+ursQG1SRiK4ie5rHGJoteGnbH91Uix/HDE5GC3Hz'
              'ICQVOnQay4hwJUKRfEUEWj1Sncer/BL2igDquABlcXNl2dgOlfJ8a3q'
              '6IZnQpdEe6Vrqg/Ui082UxuZ08pNV94M/5IhR2fx0EbY66PQ97o+ywH'
              'sB7oXDO8p/+mGL+h7cxFY7hILXTa5/3TGBEgcA65Rrmq22eiRt97RGh'
              'DjfzIqTqb8gwuhTSNN7FWDLrEyRwJMbaTgDSoMIZdLtndVrGEqFHUO+'
              'WzinSiEQCs2MDDnTk29bleHAEktu1x68GYhg9S7O/gZq8/swAV '
              'core@valid1',
    'valid2': 'command="echo \\"test\\"" ssh-dss AAAAB3NzaC1kc3MAAACBAJA94Sqw80BSKjVTNZD6570nXIN'
              'hP8R2UhbBuydT+GI6CfA9Dw7O0udJQUfrqARFcRQR/syc72CO6jaKNE'
              '3/A5E+8uVmRZt7s9VtA47s1qxqHswth74m1Nb86n2OTB0HcW63FsXo2'
              'cJF+r+l6F3IcRPi4z/eaEKG7uhAS59TjH2tAAAAFQC0I9kL3oceMT1O'
              '44WPe6NZ8w8CMwAAAIABGm2Yg8nGFZbo/W8njuM79w0W2P1NBVNWzBH'
              'WQqVbr4i1bWTSSc9X+itQUpeF6zAUDsUoprhNise2NLrMYCLFo9JxhE'
              'iYAcEJ/YbKEnjtJzaAmQNpyh3rCWuOcGPTevjAZIkl+zEc+/N7tCW1e'
              'uDYm6IXZ8LEQyTUQUdU4pZ2OgAAAIABk1ZA3+TiCMaoAafNVUZ7zwqk'
              '888yVOgsJ7HGGDGRMo5ytr2SUJB7QWsLX6Un/Zbu32nXsAqtqagxd6F'
              'Ies98TSekMh/hAv9uK92mEsXSINXOeIMKRedqOyPgk5IEOsFpxAUO4T'
              'xpYToeuM8HRemecxw2eIFHnax+mQqCsi7FgQ== core@valid2',
    'bad':    'ssh-bad this-not-a-key core@bad',
}

fingerprints = {
    'valid1': 'SHA256:yZ+o48h6quk9c+JVgJ/Zq4S5u4LUk6TSpneHKkmM9KY',
    'valid2': 'SHA256:RP5k1AybZ1kollIAnpUavr1v1nfZ0yloKvI46AMDPkM ',
}

class UpdateSshKeysTestCase(unittest.TestCase):

    def setUp(self):
        user_info = pwd.getpwuid(os.getuid())
        self.user = user_info.pw_name
        self.ssh_dir = tempfile.mkdtemp(prefix='test_update_ssh_keys')
        self.env = os.environ.copy()
        self.pub_files = {}

        for name, text in test_keys.iteritems():
            pub_path = '%s/%s.pub' % (self.ssh_dir, name)
            self.pub_files[name] = pub_path
            with open(pub_path, 'w') as pub_fd:
                pub_fd.write('%s\n' % text)

    def tearDown(self):
        shutil.rmtree(self.ssh_dir)

    def assertHasKeys(self, *keys):
        with open('%s/authorized_keys' % self.ssh_dir, 'r') as fd:
            text = fd.read()
        self.assertTrue(text.startswith('# auto-generated'))
        for key in keys:
            self.assertIn(test_keys[key], text)
        for key in test_keys:
            if key in keys:
                continue
            self.assertNotIn(test_keys[key], text)

    def run_script(self, *args, **kwargs):
        cmd = [script_path, '-u', self.user, '--ssh-dir', self.ssh_dir]
        cmd.extend(args)
        return subprocess.Popen(cmd, env=self.env,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                **kwargs)

    def test_usage(self):
        proc = self.run_script('-h')
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertTrue(out.startswith('Usage: '))
        self.assertEquals(err, '')

    def test_no_keys(self):
        proc = self.run_script()
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 1)
        self.assertEquals(out, '')
        self.assertIn('no keys found', err)
        self.assertTrue(os.path.isdir('%s/authorized_keys.d' % self.ssh_dir))
        self.assertFalse(os.path.exists('%s/authorized_keys' % self.ssh_dir))

    def test_first_run(self):
        with open('%s/authorized_keys' % self.ssh_dir, 'w') as fd:
            fd.write('%s\n' % test_keys['valid1'])
            fd.write('%s\n' % test_keys['valid2'])
        proc = self.run_script()
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertTrue(out.startswith('Updated '))
        self.assertEquals(err, '')
        self.assertTrue(os.path.exists(
                '%s/authorized_keys.d/old_authorized_keys' % self.ssh_dir))
        self.assertHasKeys('valid1', 'valid2')

    def test_add_one_file(self):
        proc = self.run_script('-a', 'one', self.pub_files['valid1'])
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertTrue(out.startswith('Adding'))
        self.assertIn(fingerprints['valid1'], out)
        self.assertIn('\nUpdated ', out)
        self.assertEquals(err, '')
        self.assertTrue(os.path.exists(
                '%s/authorized_keys.d/one' % self.ssh_dir))
        self.assertHasKeys('valid1')

    def test_add_one_stdin(self):
        proc = self.run_script('-a', 'one', stdin=subprocess.PIPE)
        out, err = proc.communicate(test_keys['valid1'])
        self.assertEquals(proc.returncode, 0)
        self.assertTrue(out.startswith('Adding'))
        self.assertIn(fingerprints['valid1'], out)
        self.assertIn('\nUpdated ', out)
        self.assertEquals(err, '')
        self.assertTrue(os.path.exists(
                '%s/authorized_keys.d/one' % self.ssh_dir))
        self.assertHasKeys('valid1')

    def test_replace_one(self):
        self.test_add_one_file()
        proc = self.run_script('-a', 'one', self.pub_files['valid2'])
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertIn(fingerprints['valid2'], out)
        self.assertEquals(err, '')
        self.assertHasKeys('valid2')

    def test_no_replace(self):
        self.test_add_one_file()
        proc = self.run_script('-n', '-a', 'one', self.pub_files['valid2'])
        out, err = proc.communicate()
        self.assertTrue(out.startswith('Skipping'))
        self.assertEquals(proc.returncode, 0)
        self.assertEquals(err, '')
        self.assertHasKeys('valid1')

        proc = self.run_script('-n', '-A', 'one', self.pub_files['valid2'])
        out, err = proc.communicate()
        self.assertTrue(out.startswith('Skipping'))
        self.assertEquals(proc.returncode, 0)
        self.assertEquals(err, '')
        self.assertHasKeys('valid1')

    def test_add_two(self):
        self.test_add_one_file()
        proc = self.run_script('-a', 'two', self.pub_files['valid2'])
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertIn(fingerprints['valid2'], out)
        self.assertEquals(err, '')
        self.assertHasKeys('valid1', 'valid2')

    def test_del_one(self):
        self.test_add_one_file()
        proc = self.run_script('-d', 'one')
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 1)
        self.assertIn(fingerprints['valid1'], out)
        self.assertIn('no keys found', err)
        # Removed from authorized_keys.d but not authorized_keys
        self.assertFalse(os.path.exists(
                '%s/authorized_keys.d/one' % self.ssh_dir))
        self.assertHasKeys('valid1')

    def test_del_two(self):
        self.test_add_two()
        proc = self.run_script('-d', 'two')
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertIn(fingerprints['valid2'], out)
        self.assertEquals(err, '')
        self.assertHasKeys('valid1')

    def test_disable(self):
        self.test_add_two()
        proc = self.run_script('-D', 'two')
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertTrue(out.startswith('Disabling'))
        self.assertIn(fingerprints['valid2'], out)
        self.assertEquals(err, '')
        self.assertHasKeys('valid1')

        proc = self.run_script('-a', 'two', self.pub_files['valid2'])
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertTrue(out.startswith('Skipping'))
        self.assertEquals(err, '')
        self.assertHasKeys('valid1')

    def test_enable(self):
        self.test_disable()
        proc = self.run_script('-A', 'two', self.pub_files['valid2'])
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertTrue(out.startswith('Adding'))
        self.assertEquals(err, '')
        self.assertHasKeys('valid1', 'valid2')

    def test_add_bad(self):
        self.test_add_one_file()
        proc = self.run_script('-a', 'bad', self.pub_files['bad'])
        out, err = proc.communicate()
        self.assertEquals(proc.returncode, 0)
        self.assertIn('warning', out)
        self.assertIn('failed to parse public key', out)
        self.assertHasKeys('valid1')


if __name__ == '__main__':
    unittest.main()